mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-24 06:03:40 +08:00
[LoRA] Continue optimizing MoE LoRA weight loading (#29322)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
cf348c8d27
commit
2f5f9acd55
@ -28,12 +28,13 @@ def test_load_checkpoints(
|
||||
packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
|
||||
embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
|
||||
embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
|
||||
expected_lora_modules: list[str] = []
|
||||
expected_lora_lst: list[str] = []
|
||||
for module in BAICHUAN_LORA_MODULES:
|
||||
if module in packed_modules_mapping:
|
||||
expected_lora_modules.extend(packed_modules_mapping[module])
|
||||
expected_lora_lst.extend(packed_modules_mapping[module])
|
||||
else:
|
||||
expected_lora_modules.append(module)
|
||||
expected_lora_lst.append(module)
|
||||
expected_lora_modules = set(expected_lora_lst)
|
||||
if lora_name == "baichuan7B":
|
||||
peft_helper = PEFTHelper.from_local_dir(
|
||||
baichuan_lora_files, max_position_embeddings=4096
|
||||
@ -103,13 +104,13 @@ def test_lora_weights_mapping(baichuan_lora_files):
|
||||
packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
|
||||
embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
|
||||
embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
|
||||
expected_lora_modules: list[str] = []
|
||||
expected_lora_lst: list[str] = []
|
||||
for module in BAICHUAN_LORA_MODULES:
|
||||
if module in packed_modules_mapping:
|
||||
expected_lora_modules.extend(packed_modules_mapping[module])
|
||||
expected_lora_lst.extend(packed_modules_mapping[module])
|
||||
else:
|
||||
expected_lora_modules.append(module)
|
||||
|
||||
expected_lora_lst.append(module)
|
||||
expected_lora_modules = set(expected_lora_lst)
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
"model.": "language_model.model.",
|
||||
|
||||
@ -26,13 +26,13 @@ def test_load_checkpoints_from_huggingface(lora_fixture_name, request):
|
||||
packed_modules_mapping = LlamaForCausalLM.packed_modules_mapping
|
||||
embedding_modules = LlamaForCausalLM.embedding_modules
|
||||
embed_padding_modules = LlamaForCausalLM.embedding_padding_modules
|
||||
expected_lora_modules: list[str] = []
|
||||
expected_lora_lst: list[str] = []
|
||||
for module in LLAMA_LORA_MODULES:
|
||||
if module in packed_modules_mapping:
|
||||
expected_lora_modules.extend(packed_modules_mapping[module])
|
||||
expected_lora_lst.extend(packed_modules_mapping[module])
|
||||
else:
|
||||
expected_lora_modules.append(module)
|
||||
|
||||
expected_lora_lst.append(module)
|
||||
expected_lora_modules = set(expected_lora_lst)
|
||||
lora_path = get_adapter_absolute_path(lora_name)
|
||||
|
||||
# lora loading should work for either absolute path and huggingface id.
|
||||
|
||||
@ -60,7 +60,7 @@ class BaseLayerWithLoRA(nn.Module):
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None,
|
||||
model_config: PretrainedConfig | None = None,
|
||||
) -> bool:
|
||||
"""Returns True if the layer can be replaced by this LoRA layer."""
|
||||
raise NotImplementedError
|
||||
|
||||
@ -153,7 +153,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None,
|
||||
model_config: PretrainedConfig | None = None,
|
||||
) -> bool:
|
||||
return type(source_layer) is ColumnParallelLinear or (
|
||||
type(source_layer) is MergedColumnParallelLinear
|
||||
@ -272,7 +272,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None,
|
||||
model_config: PretrainedConfig | None = None,
|
||||
) -> bool:
|
||||
return (
|
||||
type(source_layer) is MergedColumnParallelLinear
|
||||
@ -338,7 +338,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None,
|
||||
model_config: PretrainedConfig | None = None,
|
||||
) -> bool:
|
||||
return type(source_layer) is QKVParallelLinear and len(packed_modules_list) == 1
|
||||
|
||||
@ -396,7 +396,7 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None,
|
||||
model_config: PretrainedConfig | None = None,
|
||||
) -> bool:
|
||||
return type(source_layer) is QKVParallelLinear and len(packed_modules_list) == 3
|
||||
|
||||
@ -434,7 +434,7 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None,
|
||||
model_config: PretrainedConfig | None = None,
|
||||
) -> bool:
|
||||
# specifying kwargs so they can be easily accessed in decorator
|
||||
return super().can_replace_layer(
|
||||
@ -480,7 +480,7 @@ class MergedColumnParallelLinearWithShardedLoRA(MergedColumnParallelLinearWithLo
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None,
|
||||
model_config: PretrainedConfig | None = None,
|
||||
) -> bool:
|
||||
# specifying kwargs so they can be easily accessed in decorator
|
||||
return super().can_replace_layer(
|
||||
@ -516,7 +516,7 @@ class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA):
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None,
|
||||
model_config: PretrainedConfig | None = None,
|
||||
) -> bool:
|
||||
# specifying kwargs so they can be easily accessed in decorator
|
||||
return super().can_replace_layer(
|
||||
@ -565,7 +565,7 @@ class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA):
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None,
|
||||
model_config: PretrainedConfig | None = None,
|
||||
) -> bool:
|
||||
# specifying kwargs so they can be easily accessed in decorator
|
||||
return super().can_replace_layer(
|
||||
|
||||
@ -401,6 +401,61 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||
self.w13_lora_b_stacked[1][lora_id][experts_id]
|
||||
)
|
||||
|
||||
def _slice_w13_a(self, w13_lora_a: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
|
||||
"""
|
||||
if self.tp_size == 1 or not self.fully_sharded:
|
||||
return w13_lora_a
|
||||
|
||||
# w13_lora_a shape (num_experts,rank,input_size)
|
||||
current_lora_rank = w13_lora_a.shape[1]
|
||||
assert current_lora_rank % self.tp_size == 0
|
||||
# Based on S-LoRA, we slice W13/W1/W3 A along the rank dim.
|
||||
sliced_rank = current_lora_rank // self.tp_size
|
||||
start_idx = self.tp_rank * sliced_rank
|
||||
end_idx = (self.tp_rank + 1) * sliced_rank
|
||||
return w13_lora_a[:, start_idx:end_idx, :]
|
||||
|
||||
def _slice_w13_b(self, w13_lora_b: torch.Tensor):
|
||||
if self.tp_size == 1:
|
||||
return w13_lora_b
|
||||
|
||||
# w13_lora_b shape (num_experts,output_size,rank)
|
||||
shard_size = self.base_layer.intermediate_size_per_partition
|
||||
start_idx = self.tp_rank * shard_size
|
||||
end_idx = (self.tp_rank + 1) * shard_size
|
||||
|
||||
return w13_lora_b[:, start_idx:end_idx, :]
|
||||
|
||||
def _slice_w2_a(self, w2_lora_a: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
|
||||
"""
|
||||
if self.tp_size == 1:
|
||||
return w2_lora_a
|
||||
# w2_lora_a shape (num_experts,rank,input_size)
|
||||
shard_size = self.base_layer.intermediate_size_per_partition
|
||||
start_idx = self.tp_rank * shard_size
|
||||
end_idx = (self.tp_rank + 1) * shard_size
|
||||
|
||||
return w2_lora_a[:, :, start_idx:end_idx]
|
||||
|
||||
def _slice_w2_b(self, w2_lora_b: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
|
||||
"""
|
||||
if self.tp_size == 1 or not self.fully_sharded:
|
||||
return w2_lora_b
|
||||
# Based on S-LoRA, we slice W2 B along the hidden_size dim.
|
||||
# w2_lora_b shape (num_experts,output_size,rank)
|
||||
current_lora_size = w2_lora_b.shape[1]
|
||||
|
||||
sliced_size = current_lora_size // self.tp_size
|
||||
start_idx = self.tp_rank * sliced_size
|
||||
end_idx = (self.tp_rank + 1) * sliced_size
|
||||
return w2_lora_b[:, start_idx:end_idx, :]
|
||||
|
||||
def reset_lora(self, index: int):
|
||||
"""Resets the lora weights at index back to 0."""
|
||||
for pos in range(self._w13_slices):
|
||||
@ -411,6 +466,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||
self.w2_lora_b_stacked[0][index] = 0
|
||||
self.adapter_enabled[index] = 0
|
||||
|
||||
#
|
||||
|
||||
def set_lora(
|
||||
self,
|
||||
index: int,
|
||||
@ -418,69 +475,55 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||
lora_b: torch.Tensor | list[torch.Tensor],
|
||||
):
|
||||
"""Overwrites lora tensors at index."""
|
||||
# Make mypy happy
|
||||
assert isinstance(lora_a, list)
|
||||
assert isinstance(lora_b, list)
|
||||
|
||||
self.reset_lora(index)
|
||||
self.adapter_enabled[index] = 1
|
||||
for eid in range(len(lora_a) // 3):
|
||||
w1_lora_a = lora_a[eid * 3]
|
||||
w2_lora_a = lora_a[eid * 3 + 1]
|
||||
w3_lora_a = lora_a[eid * 3 + 2]
|
||||
w1_lora_b = lora_b[eid * 3]
|
||||
w2_lora_b = lora_b[eid * 3 + 1]
|
||||
w3_lora_b = lora_b[eid * 3 + 2]
|
||||
|
||||
# Handle the case of adding LoRA to only a subset of experts
|
||||
if w1_lora_a is None or w2_lora_a is None or w3_lora_a is None:
|
||||
continue
|
||||
num_experts = self.w13_lora_a_stacked[0].shape[1]
|
||||
|
||||
if self.tp_size > 1:
|
||||
shard_size = self.base_layer.intermediate_size_per_partition
|
||||
start_idx = self.tp_rank * shard_size
|
||||
end_idx = (self.tp_rank + 1) * shard_size
|
||||
w1_lora_a, w2_lora_a, w3_lora_a = lora_a
|
||||
w1_lora_b, w2_lora_b, w3_lora_b = lora_b
|
||||
assert (
|
||||
num_experts
|
||||
== w1_lora_a.shape[0]
|
||||
== w2_lora_a.shape[0]
|
||||
== w3_lora_a.shape[0]
|
||||
)
|
||||
|
||||
w1_lora_b = w1_lora_b[start_idx:end_idx, :]
|
||||
w3_lora_b = w3_lora_b[start_idx:end_idx, :]
|
||||
w2_lora_a = w2_lora_a[:, start_idx:end_idx]
|
||||
slliced_w1_lora_a = self._slice_w13_a(w1_lora_a)
|
||||
slliced_w1_lora_b = self._slice_w13_b(w1_lora_b)
|
||||
slliced_w3_lora_a = self._slice_w13_a(w3_lora_a)
|
||||
slliced_w3_lora_b = self._slice_w13_b(w3_lora_b)
|
||||
|
||||
if self.fully_sharded:
|
||||
# Based on S-LoRA, we slice W1 and W3 A along the rank dim,
|
||||
# and W2 B along the hidden_size dim.
|
||||
w13_shard_size = self.w13_lora_a_stacked[0][index, eid].shape[0]
|
||||
w13_start_idx = self.tp_rank * w13_shard_size
|
||||
w13_end_idx = (self.tp_rank + 1) * w13_shard_size
|
||||
w1_lora_a = w1_lora_a[w13_start_idx:w13_end_idx, :]
|
||||
w3_lora_a = w3_lora_a[w13_start_idx:w13_end_idx, :]
|
||||
sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
|
||||
sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)
|
||||
|
||||
w2_shard_size = self.w2_lora_b_stacked[0][index, eid].shape[0]
|
||||
w2_start_idx = self.tp_rank * w2_shard_size
|
||||
w2_end_idx = (self.tp_rank + 1) * w2_shard_size
|
||||
w2_lora_b = w2_lora_b[w2_start_idx:w2_end_idx, :]
|
||||
# w1 lora_a
|
||||
self.w13_lora_a_stacked[0][
|
||||
index, eid, : w1_lora_a.shape[0], : w1_lora_a.shape[1]
|
||||
].copy_(w1_lora_a, non_blocking=True)
|
||||
# w3 lora_a
|
||||
self.w13_lora_a_stacked[1][
|
||||
index, eid, : w3_lora_a.shape[0], : w3_lora_a.shape[1]
|
||||
].copy_(w3_lora_a, non_blocking=True)
|
||||
self.w13_lora_a_stacked[0][
|
||||
index, :, : slliced_w1_lora_a.shape[1], : slliced_w1_lora_a.shape[2]
|
||||
].copy_(slliced_w1_lora_a, non_blocking=True)
|
||||
|
||||
# w1 lora_b
|
||||
self.w13_lora_b_stacked[0][
|
||||
index, eid, : w1_lora_b.shape[0], : w1_lora_b.shape[1]
|
||||
].copy_(w1_lora_b, non_blocking=True)
|
||||
# w3 lora_b
|
||||
self.w13_lora_b_stacked[1][
|
||||
index, eid, : w3_lora_b.shape[0], : w3_lora_b.shape[1]
|
||||
].copy_(w3_lora_b, non_blocking=True)
|
||||
self.w13_lora_a_stacked[1][
|
||||
index, :, : slliced_w3_lora_a.shape[1], : slliced_w3_lora_a.shape[2]
|
||||
].copy_(slliced_w3_lora_a, non_blocking=True)
|
||||
|
||||
self.w2_lora_a_stacked[0][
|
||||
index, eid, : w2_lora_a.shape[0], : w2_lora_a.shape[1]
|
||||
].copy_(w2_lora_a, non_blocking=True)
|
||||
self.w13_lora_b_stacked[0][
|
||||
index, :, : slliced_w1_lora_b.shape[1], : slliced_w1_lora_b.shape[2]
|
||||
].copy_(slliced_w1_lora_b, non_blocking=True)
|
||||
|
||||
self.w2_lora_b_stacked[0][
|
||||
index, eid, : w2_lora_b.shape[0], : w2_lora_b.shape[1]
|
||||
].copy_(w2_lora_b, non_blocking=True)
|
||||
self.w13_lora_b_stacked[1][
|
||||
index, :, : slliced_w3_lora_b.shape[1], : slliced_w3_lora_b.shape[2]
|
||||
].copy_(slliced_w3_lora_b, non_blocking=True)
|
||||
|
||||
self.w2_lora_a_stacked[0][
|
||||
index, :, : sliced_w2_lora_a.shape[1], : sliced_w2_lora_a.shape[2]
|
||||
].copy_(sliced_w2_lora_a, non_blocking=True)
|
||||
|
||||
self.w2_lora_b_stacked[0][
|
||||
index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2]
|
||||
].copy_(sliced_w2_lora_b, non_blocking=True)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.base_layer.forward(*args, **kwargs)
|
||||
@ -506,12 +549,12 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None,
|
||||
model_config: PretrainedConfig | None = None,
|
||||
) -> bool:
|
||||
"""Returns True if the layer can be replaced by this LoRA layer."""
|
||||
# return type(source_layer) is FusedMoE
|
||||
|
||||
return type(source_layer) is FusedMoE and len(packed_modules_list) == 2
|
||||
# source_layer is FusedMoE or SharedFusedMoE
|
||||
return isinstance(source_layer, FusedMoE) and len(packed_modules_list) == 2
|
||||
|
||||
|
||||
class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
|
||||
@ -555,6 +598,9 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
|
||||
model_config: PretrainedConfig | None = None,
|
||||
) -> None:
|
||||
"""Initializes lora matrices."""
|
||||
|
||||
assert isinstance(model_config, PretrainedConfig)
|
||||
self._base_model = model_config.architectures[0]
|
||||
self.max_loras = lora_config.max_loras
|
||||
self.fully_sharded = lora_config.fully_sharded_loras
|
||||
|
||||
@ -565,20 +611,7 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
|
||||
self._create_lora_a_weights(max_loras, lora_config)
|
||||
self._create_lora_b_weights(max_loras, lora_config)
|
||||
|
||||
def _slice_w13_a(self, w13_lora_a: torch.Tensor) -> torch.Tensor:
|
||||
if self.tp_size == 1 or not self.fully_sharded:
|
||||
return w13_lora_a
|
||||
|
||||
# w13_lora_a shape (num_experts,rank,input_size)
|
||||
current_lora_rank = w13_lora_a.shape[1]
|
||||
assert current_lora_rank % self.tp_size == 0
|
||||
|
||||
sliced_rank = current_lora_rank // self.tp_size
|
||||
start_idx = self.tp_rank * sliced_rank
|
||||
end_idx = (self.tp_rank + 1) * sliced_rank
|
||||
return w13_lora_a[:, start_idx:end_idx, :]
|
||||
|
||||
def _slice_w13_b(self, w13_lora_b: torch.Tensor, is_interleave: bool = True):
|
||||
def _slice_w13_b(self, w13_lora_b: torch.Tensor):
|
||||
if self.tp_size == 1:
|
||||
return w13_lora_b
|
||||
|
||||
@ -586,7 +619,8 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
|
||||
shard_size = self.base_layer.intermediate_size_per_partition
|
||||
start_idx = self.tp_rank * shard_size
|
||||
end_idx = (self.tp_rank + 1) * shard_size
|
||||
if is_interleave:
|
||||
# HACK: Currently, only GPT-OSS is in interleaved order
|
||||
if self._base_model == "GptOssForCausalLM":
|
||||
# For models like GPT-OSS, the weights of w1 (gate_proj) and w3 (up_proj)
|
||||
# in the interleaved order, and corresponding LoRA need to be processed.
|
||||
w1_lora_b = w13_lora_b[:, ::2, :]
|
||||
@ -606,28 +640,6 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
|
||||
|
||||
return torch.cat([sliced_w1_lora_b, sliced_w3_lora_b], dim=1)
|
||||
|
||||
def _slice_w2_a(self, w2_lora_a: torch.Tensor) -> torch.Tensor:
|
||||
if self.tp_size == 1:
|
||||
return w2_lora_a
|
||||
# w2_lora_a shape (num_experts,rank,input_size)
|
||||
shard_size = self.base_layer.intermediate_size_per_partition
|
||||
start_idx = self.tp_rank * shard_size
|
||||
end_idx = (self.tp_rank + 1) * shard_size
|
||||
|
||||
return w2_lora_a[:, :, start_idx:end_idx]
|
||||
|
||||
def _slice_w2_b(self, w2_lora_b: torch.Tensor) -> torch.Tensor:
|
||||
if self.tp_size == 1 or not self.fully_sharded:
|
||||
return w2_lora_b
|
||||
# Based on S-LoRA, we slice W2 B along the hidden_size dim.
|
||||
# w2_lora_b shape (num_experts,output_size,rank)
|
||||
current_lora_size = w2_lora_b.shape[1]
|
||||
|
||||
sliced_size = current_lora_size // self.tp_size
|
||||
start_idx = self.tp_rank * sliced_size
|
||||
end_idx = (self.tp_rank + 1) * sliced_size
|
||||
return w2_lora_b[:, start_idx:end_idx, :]
|
||||
|
||||
def set_lora(
|
||||
self,
|
||||
index: int,
|
||||
@ -658,7 +670,7 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
|
||||
w2_lora_b = w2_lora_b.permute(1, 0, 2)
|
||||
|
||||
sliced_w13_lora_a = self._slice_w13_a(w13_lora_a)
|
||||
sliced_w13_lora_b = self._slice_w13_b(w13_lora_b, is_interleave=True)
|
||||
sliced_w13_lora_b = self._slice_w13_b(w13_lora_b)
|
||||
|
||||
sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
|
||||
sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)
|
||||
@ -711,8 +723,8 @@ class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None,
|
||||
model_config: PretrainedConfig | None = None,
|
||||
) -> bool:
|
||||
"""Returns True if the layer can be replaced by this LoRA layer."""
|
||||
|
||||
return type(source_layer) is FusedMoE and len(packed_modules_list) == 1
|
||||
# source_layer is FusedMoE or SharedFusedMoE
|
||||
return isinstance(source_layer, FusedMoE) and len(packed_modules_list) == 1
|
||||
|
||||
@ -197,7 +197,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None,
|
||||
model_config: PretrainedConfig | None = None,
|
||||
) -> bool:
|
||||
# Special handling for the LogitsProcessor.
|
||||
return False
|
||||
|
||||
@ -53,7 +53,7 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None,
|
||||
model_config: PretrainedConfig | None = None,
|
||||
) -> bool:
|
||||
return type(source_layer) is ReplicatedLinear
|
||||
|
||||
|
||||
@ -87,7 +87,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None,
|
||||
model_config: PretrainedConfig | None = None,
|
||||
) -> bool:
|
||||
return type(source_layer) is RowParallelLinear
|
||||
|
||||
@ -164,7 +164,7 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None,
|
||||
model_config: PretrainedConfig | None = None,
|
||||
) -> bool:
|
||||
# specifying kwargs so they can be easily accessed in decorator
|
||||
return super().can_replace_layer(
|
||||
|
||||
@ -131,7 +131,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None,
|
||||
model_config: PretrainedConfig | None = None,
|
||||
) -> bool:
|
||||
return type(source_layer) is VocabParallelEmbedding
|
||||
|
||||
|
||||
@ -152,6 +152,59 @@ class PackedLoRALayerWeights(LoRALayerWeights):
|
||||
)
|
||||
return obj
|
||||
|
||||
@classmethod
|
||||
def pack_moe(
|
||||
cls, loras: GenericSequence[Optional["LoRALayerWeights"]], module_name: str
|
||||
) -> "PackedLoRALayerWeights":
|
||||
"""Pack a list of LoRAs into a single LoRA.
|
||||
|
||||
If LoRA is None, it signifies that the submodule does not have a LoRA.
|
||||
"""
|
||||
|
||||
first_lora = next(lora for lora in loras if lora is not None)
|
||||
assert first_lora is not None
|
||||
rank = first_lora.rank
|
||||
lora_alpha = first_lora.lora_alpha
|
||||
assert len(loras) % 3 == 0
|
||||
w1_lora_a_lst = []
|
||||
w2_lora_a_lst = []
|
||||
w3_lora_a_lst = []
|
||||
w1_lora_b_lst = []
|
||||
w2_lora_b_lst = []
|
||||
w3_lora_b_lst = []
|
||||
# TODO: Consider the case where some experts don't have LoRA added.
|
||||
for eid in range(len(loras) // 3):
|
||||
w1_lora = loras[eid * 3]
|
||||
w2_lora = loras[eid * 3 + 1]
|
||||
w3_lora = loras[eid * 3 + 2]
|
||||
assert w1_lora is not None
|
||||
assert w2_lora is not None
|
||||
assert w3_lora is not None
|
||||
|
||||
w1_lora_a_lst.append(w1_lora.lora_a)
|
||||
w2_lora_a_lst.append(w2_lora.lora_a)
|
||||
w3_lora_a_lst.append(w3_lora.lora_a)
|
||||
|
||||
w1_lora_b_lst.append(w1_lora.lora_b)
|
||||
w2_lora_b_lst.append(w2_lora.lora_b)
|
||||
w3_lora_b_lst.append(w3_lora.lora_b)
|
||||
|
||||
w1_lora_a = torch.stack(w1_lora_a_lst, dim=0) # (num_experts,rank,input_size)
|
||||
w2_lora_a = torch.stack(w2_lora_a_lst, dim=0)
|
||||
w3_lora_a = torch.stack(w3_lora_a_lst, dim=0)
|
||||
w1_lora_b = torch.stack(w1_lora_b_lst, dim=0) # (num_experts,output_size,rank)
|
||||
w2_lora_b = torch.stack(w2_lora_b_lst, dim=0)
|
||||
w3_lora_b = torch.stack(w3_lora_b_lst, dim=0)
|
||||
|
||||
obj = cls(
|
||||
module_name,
|
||||
rank,
|
||||
[lora_alpha, lora_alpha, lora_alpha],
|
||||
[w1_lora_a, w2_lora_a, w3_lora_a],
|
||||
[w1_lora_b, w2_lora_b, w3_lora_b],
|
||||
)
|
||||
return obj
|
||||
|
||||
def optimize(self) -> "PackedLoRALayerWeights":
|
||||
"""Optimize the LoRA by merging the scaling into lora_b."""
|
||||
for i in range(len(self.lora_b)):
|
||||
|
||||
@ -13,7 +13,7 @@ from torch import nn
|
||||
|
||||
from vllm.config.lora import LoRAConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.layers import BaseLayerWithLoRA, FusedMoEWithLoRA, LoRAMapping
|
||||
from vllm.lora.layers import BaseLayerWithLoRA, FusedMoE3DWithLoRA, LoRAMapping
|
||||
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
|
||||
from vllm.lora.peft_helper import PEFTHelper
|
||||
from vllm.lora.punica_wrapper import get_punica_wrapper
|
||||
@ -151,16 +151,13 @@ class LoRAModel:
|
||||
if pin_memory:
|
||||
loras[module_name].lora_b = loras[module_name].lora_b.pin_memory()
|
||||
|
||||
for lora in loras.values():
|
||||
lora.optimize()
|
||||
|
||||
return cls(lora_model_id, peft_helper.r, loras)
|
||||
|
||||
@classmethod
|
||||
def from_local_checkpoint(
|
||||
cls,
|
||||
lora_dir: str,
|
||||
expected_lora_modules: list[str],
|
||||
expected_lora_modules: set[str],
|
||||
peft_helper: PEFTHelper,
|
||||
*,
|
||||
lora_model_id: int | None = None,
|
||||
@ -190,10 +187,7 @@ class LoRAModel:
|
||||
lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
|
||||
lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
|
||||
lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt")
|
||||
# new_embeddings_tensor_path = os.path.join(
|
||||
# lora_dir, "new_embeddings.safetensors"
|
||||
# )
|
||||
# new_embeddings_bin_file_path = os.path.join(lora_dir, "new_embeddings.bin")
|
||||
|
||||
tensors: dict[str, torch.Tensor] = {}
|
||||
unexpected_modules: list[list[str] | str] = []
|
||||
|
||||
@ -201,18 +195,19 @@ class LoRAModel:
|
||||
for lora_module in modules.keys(): # noqa
|
||||
if is_base_embeddding_weights(lora_module):
|
||||
continue
|
||||
module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper)
|
||||
# Handle FSDP file format where experts.base_layer is the
|
||||
# Handle PEFT file format where experts.base_layer is the
|
||||
# gate_up_proj and experts is the down_proj
|
||||
if "base_layer" in lora_module:
|
||||
continue
|
||||
module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper)
|
||||
# Case for expert lora weights
|
||||
if ".experts" in module_name:
|
||||
if not any(
|
||||
module_name.endswith(ele) for ele in expected_lora_modules
|
||||
):
|
||||
expert_idx = module_name.find(".experts")
|
||||
expert_suffix = module_name[expert_idx + 1 :]
|
||||
if expert_suffix not in expected_lora_modules:
|
||||
unexpected_modules.append(module_name)
|
||||
elif module_name.split(".")[-1] not in expected_lora_modules:
|
||||
|
||||
elif module_name.rsplit(".", 1)[-1] not in expected_lora_modules:
|
||||
unexpected_modules.append(module_name)
|
||||
|
||||
if unexpected_modules:
|
||||
@ -358,9 +353,7 @@ class LoRAModelManager:
|
||||
self.modules: dict[str, BaseLayerWithLoRA] = {}
|
||||
# Dict instead of a set for compatibility with LRUCache.
|
||||
self._last_mapping: LoRAMapping | None = None
|
||||
self._is_3d_moe_model = is_moe_model(self.model) and hasattr(
|
||||
self.model, "is_3d_moe_weight"
|
||||
)
|
||||
self._is_3d_moe_model = is_moe_model(self.model) and self.model.is_3d_moe_weight
|
||||
self._create_lora_modules()
|
||||
|
||||
self.model.lora_manager = self
|
||||
@ -411,7 +404,7 @@ class LoRAModelManager:
|
||||
continue
|
||||
# Note (gnovack) - If MOE lora weights are not split into
|
||||
# num_experts chunks, we split them here
|
||||
if isinstance(module, FusedMoEWithLoRA) and torch.is_tensor(
|
||||
if isinstance(module, FusedMoE3DWithLoRA) and torch.is_tensor(
|
||||
module_lora.lora_a
|
||||
):
|
||||
# Handle PEFT file format where experts.base_layer is the
|
||||
@ -679,7 +672,10 @@ class LoRAModelManager:
|
||||
"cpu",
|
||||
)
|
||||
subloras.append(lora)
|
||||
lora = PackedLoRALayerWeights.pack(subloras)
|
||||
if module.__class__.__name__ == "FusedMoEWithLoRA":
|
||||
lora = PackedLoRALayerWeights.pack_moe(subloras, module_name)
|
||||
else:
|
||||
lora = PackedLoRALayerWeights.pack(subloras)
|
||||
model.loras[module_name] = lora
|
||||
return model
|
||||
|
||||
@ -739,13 +735,21 @@ class LoRAModelManager:
|
||||
replaced_module_name = module_name.replace("model.", "")
|
||||
if lora_model.check_lora_name(module_name):
|
||||
module_name = replaced_module_name
|
||||
lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
|
||||
replacement_loras
|
||||
)
|
||||
if module_name.endswith(".experts"):
|
||||
lora_model.loras[module_name] = PackedLoRALayerWeights.pack_moe(
|
||||
replacement_loras, module_name
|
||||
)
|
||||
else:
|
||||
lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
|
||||
replacement_loras
|
||||
)
|
||||
# Remove the modules that have been replaced.
|
||||
for module in replaced_module:
|
||||
lora_model.loras.pop(module, None)
|
||||
|
||||
for lora in lora_model.loras.values():
|
||||
lora.optimize()
|
||||
|
||||
def _get_lora_layer_weights(
|
||||
self, lora_model: LoRAModel, module_name: str
|
||||
) -> LoRALayerWeights | None:
|
||||
|
||||
@ -170,16 +170,15 @@ def parse_fine_tuned_lora_name(
|
||||
|
||||
def is_base_embeddding_weights(name: str) -> bool:
|
||||
# hardcoded subfixes for input & output embedding weights
|
||||
input_embedding_subfix = ".embed_tokens.base_layer.weight"
|
||||
output_embedding_subfix = ".lm_head.base_layer.weight"
|
||||
|
||||
return name.endswith(input_embedding_subfix) or name.endswith(
|
||||
output_embedding_subfix
|
||||
embedding_suffixes = (
|
||||
".embed_tokens.base_layer.weight",
|
||||
".lm_head.base_layer.weight",
|
||||
)
|
||||
return name.endswith(embedding_suffixes)
|
||||
|
||||
|
||||
def is_regex_target_modules(
|
||||
load_modules: str | list[str], expected_lora_modules: list[str]
|
||||
load_modules: str | list[str], expected_lora_modules: set[str]
|
||||
) -> bool:
|
||||
"""
|
||||
PEFT supports passing `target_modules` in the form of regular expressions,
|
||||
@ -195,8 +194,8 @@ def is_regex_target_modules(
|
||||
except re.error:
|
||||
return False
|
||||
|
||||
def is_subset(sub_list, full_list):
|
||||
return set(sub_list).issubset(set(full_list))
|
||||
def is_subset(sub_list, full_set):
|
||||
return set(sub_list).issubset(full_set)
|
||||
|
||||
# Similar to PEFT's processing logic, regex-related operations are only
|
||||
# executed when the load_modules is a `str`.
|
||||
@ -290,7 +289,7 @@ def process_packed_modules_mapping(model: nn.Module) -> dict[str, list[str]]:
|
||||
# the expert indices are expanded based on the configured number
|
||||
# of routed experts.
|
||||
packed_modules_mapping = get_packed_modules_mapping(model)
|
||||
if not hasattr(model, "is_3d_moe_weight"):
|
||||
if not model.is_3d_moe_weight:
|
||||
# 3D MoE LoRA does not need `packed_modules_mapping`
|
||||
packed_modules_mapping["experts"] = [
|
||||
weight_name.rstrip(".")
|
||||
|
||||
@ -88,15 +88,15 @@ class WorkerLoRAManager:
|
||||
try:
|
||||
supported_lora_modules = self._adapter_manager.supported_lora_modules
|
||||
packed_modules_mapping = self._adapter_manager.packed_modules_mapping
|
||||
expected_lora_modules: list[str] = []
|
||||
expected_lora_lst: list[str] = []
|
||||
for module in supported_lora_modules:
|
||||
if module in packed_modules_mapping:
|
||||
expected_lora_modules.extend(packed_modules_mapping[module])
|
||||
expected_lora_lst.extend(packed_modules_mapping[module])
|
||||
else:
|
||||
expected_lora_modules.append(module)
|
||||
expected_lora_lst.append(module)
|
||||
if module == "experts":
|
||||
expected_lora_modules.append(module)
|
||||
expected_lora_modules = list(set(expected_lora_modules))
|
||||
expected_lora_lst.append(module)
|
||||
expected_lora_modules = set(expected_lora_lst)
|
||||
lora_path = get_adapter_absolute_path(lora_request.lora_path)
|
||||
|
||||
peft_helper = PEFTHelper.from_local_dir(
|
||||
|
||||
@ -336,6 +336,7 @@ class SupportsLoRA(Protocol):
|
||||
There is no need to redefine this flag if this class is in the
|
||||
MRO of your model class.
|
||||
"""
|
||||
is_3d_moe_weight: ClassVar[bool] = False
|
||||
# The `embedding_module` and `embedding_padding_modules`
|
||||
# are empty by default.
|
||||
embedding_modules: ClassVar[dict[str, str]] = {}
|
||||
|
||||
@ -401,6 +401,7 @@ class Qwen3VLMoeMixtureOfExperts(MixtureOfExperts):
|
||||
class Qwen3VLMoeForConditionalGeneration(
|
||||
Qwen3VLForConditionalGeneration, Qwen3VLMoeMixtureOfExperts
|
||||
):
|
||||
is_3d_moe_weight: bool = True
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user