[Bugfix] Fix missing lora name mapping for lora without prefix (#17793)

Signed-off-by: Isotr0py <2037008807@qq.com>
Author: Isotr0py, 2025-05-08 00:17:12 +08:00 (committed by GitHub)
parent 646a31e51e
commit f98e307588
2 changed files with 61 additions and 14 deletions
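
For context (a summary of the bug, not part of the commit message): `parse_fine_tuned_lora_name` runs LoRA weight names through an optional `WeightsMapper`, whose `orig_to_new_prefix` table rewrites matching name prefixes. A minimal sketch of that mapping, using the same mapper the new tests construct below:

    from vllm.model_executor.models.utils import WeightsMapper

    # Rewrites names under "model." to live under "language_model.model.",
    # as used by multimodal models whose language model is a submodule.
    mapper = WeightsMapper(
        orig_to_new_prefix={"model.": "language_model.model."})

    print(mapper._map_name("model.layers.9.mlp.down_proj.lora_A.weight"))
    # -> "language_model.model.layers.9.mlp.down_proj.lora_A.weight"

Before this fix, the mapper was only applied to names carrying the usual `base_model.model.` prefix; names without it skipped the mapping step entirely.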

tests/lora/test_utils.py

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from collections import OrderedDict
+from typing import NamedTuple, Optional
 from unittest.mock import patch
 
 import pytest
@@ -9,52 +10,96 @@ from torch import nn
 
 from vllm.lora.utils import (get_adapter_absolute_path,
                              parse_fine_tuned_lora_name, replace_submodule)
+from vllm.model_executor.models.utils import WeightsMapper
+
+
+class LoRANameParserTestConfig(NamedTuple):
+    name: str
+    module_name: str
+    is_lora_a: bool
+    is_bias: bool
+    weights_mapper: Optional[WeightsMapper] = None
 
 
 def test_parse_fine_tuned_lora_name_valid():
-    fixture = {
-        ("base_model.model.lm_head.lora_A.weight", "lm_head", True, False),
-        ("base_model.model.lm_head.lora_B.weight", "lm_head", False, False),
-        (
+    fixture = [
+        LoRANameParserTestConfig("base_model.model.lm_head.lora_A.weight",
+                                 "lm_head", True, False),
+        LoRANameParserTestConfig("base_model.model.lm_head.lora_B.weight",
+                                 "lm_head", False, False),
+        LoRANameParserTestConfig(
             "base_model.model.model.embed_tokens.lora_embedding_A",
             "model.embed_tokens",
             True,
             False,
         ),
-        (
+        LoRANameParserTestConfig(
            "base_model.model.model.embed_tokens.lora_embedding_B",
            "model.embed_tokens",
            False,
            False,
        ),
-        (
+        LoRANameParserTestConfig(
            "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
            "model.layers.9.mlp.down_proj",
            True,
            False,
        ),
-        (
+        LoRANameParserTestConfig(
            "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
            "model.layers.9.mlp.down_proj",
            False,
            False,
        ),
-        (
+        LoRANameParserTestConfig(
            "language_model.layers.9.mlp.down_proj.lora_A.weight",
            "language_model.layers.9.mlp.down_proj",
            True,
            False,
        ),
-        (
+        LoRANameParserTestConfig(
            "language_model.layers.9.mlp.down_proj.lora_B.weight",
            "language_model.layers.9.mlp.down_proj",
            False,
            False,
        ),
-    }
-    for name, module_name, is_lora_a, is_bias in fixture:
+        # Test with WeightsMapper
+        LoRANameParserTestConfig(
+            "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
+            "language_model.model.layers.9.mlp.down_proj",
+            True,
+            False,
+            weights_mapper=WeightsMapper(
+                orig_to_new_prefix={"model.": "language_model.model."}),
+        ),
+        LoRANameParserTestConfig(
+            "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
+            "language_model.model.layers.9.mlp.down_proj",
+            False,
+            False,
+            weights_mapper=WeightsMapper(
+                orig_to_new_prefix={"model.": "language_model.model."}),
+        ),
+        LoRANameParserTestConfig(
+            "model.layers.9.mlp.down_proj.lora_A.weight",
+            "language_model.model.layers.9.mlp.down_proj",
+            True,
+            False,
+            weights_mapper=WeightsMapper(
+                orig_to_new_prefix={"model.": "language_model.model."}),
+        ),
+        LoRANameParserTestConfig(
+            "model.layers.9.mlp.down_proj.lora_B.weight",
+            "language_model.model.layers.9.mlp.down_proj",
+            False,
+            False,
+            weights_mapper=WeightsMapper(
+                orig_to_new_prefix={"model.": "language_model.model."}),
+        ),
+    ]
+    for name, module_name, is_lora_a, is_bias, weights_mapper in fixture:
         assert (module_name, is_lora_a,
-                is_bias) == parse_fine_tuned_lora_name(name)
+                is_bias) == parse_fine_tuned_lora_name(name, weights_mapper)
 
 
 def test_parse_fine_tuned_lora_name_invalid():
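
With `weights_mapper` threaded through the fixture, the new cases assert that prefix-less names now resolve correctly. A usage sketch distilled from the fixture above:

    from vllm.lora.utils import parse_fine_tuned_lora_name
    from vllm.model_executor.models.utils import WeightsMapper

    mapper = WeightsMapper(
        orig_to_new_prefix={"model.": "language_model.model."})

    # A LoRA weight name without the "base_model.model." prefix is mapped
    # before being split into (module_name, is_lora_a, is_bias).
    module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name(
        "model.layers.9.mlp.down_proj.lora_A.weight", mapper)
    assert module_name == "language_model.model.layers.9.mlp.down_proj"
    assert is_lora_a and not is_bias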

vllm/lora/utils.py

@@ -117,16 +117,18 @@ def parse_fine_tuned_lora_name(
     # LoRA weight qualified name usually starts with `base_model.model.`,
     # so we remove the prefix `base_model.model.` to make the following
     # mapping correctly.
-    if "base_model.model." in name:
+    if name.startswith("base_model.model."):
         name = name.replace("base_model.model.", "")
         name = weights_mapper._map_name(name) if weights_mapper else name
         # recover the prefix `base_model.model.`
         name = "base_model.model." + name
+    else:
+        name = weights_mapper._map_name(name) if weights_mapper else name
 
     # In some situations, we may not start with `base_model.model.`.
     # If we don't (e.g., ibm-granite/granite-speech-3.3-8b),
     # we should keep the prefix intact.
-    start_index = 2 if "base_model.model." in name else 0
+    start_index = 2 if name.startswith("base_model.model.") else 0
 
     parts = name.split(".")
     if parts[-1] == "weight" and (parts[-2] == "lora_A"
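
Taken together, the patched control flow is equivalent to the following standalone sketch (`_map_lora_name` is a hypothetical helper for illustration; the real function goes on to split the name and derive the `lora_A`/`lora_B` and bias flags):

    def _map_lora_name(name, weights_mapper=None):
        # Hypothetical helper mirroring the patched branch above.
        if name.startswith("base_model.model."):
            # Strip the standard PEFT prefix so the mapper sees bare
            # module paths, then restore the prefix afterwards.
            name = name.replace("base_model.model.", "")
            name = weights_mapper._map_name(name) if weights_mapper else name
            name = "base_model.model." + name
        else:
            # The fix: prefix-less names (e.g. granite-speech adapters)
            # now pass through the mapper as well.
            name = weights_mapper._map_name(name) if weights_mapper else name
        return name

Switching the check from `"base_model.model." in name` to `name.startswith("base_model.model.")` also means a name that merely contains that substring mid-string is no longer treated as prefixed.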