mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-11 08:15:01 +08:00
[Bugfix] Fix missing lora name mapping for lora without prefix (#17793)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent
646a31e51e
commit
f98e307588
@ -1,6 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
from typing import NamedTuple, Optional
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -9,52 +10,96 @@ from torch import nn
|
|||||||
|
|
||||||
from vllm.lora.utils import (get_adapter_absolute_path,
|
from vllm.lora.utils import (get_adapter_absolute_path,
|
||||||
parse_fine_tuned_lora_name, replace_submodule)
|
parse_fine_tuned_lora_name, replace_submodule)
|
||||||
|
from vllm.model_executor.models.utils import WeightsMapper
|
||||||
|
|
||||||
|
|
||||||
|
class LoRANameParserTestConfig(NamedTuple):
|
||||||
|
name: str
|
||||||
|
module_name: str
|
||||||
|
is_lora_a: bool
|
||||||
|
is_bias: bool
|
||||||
|
weights_mapper: Optional[WeightsMapper] = None
|
||||||
|
|
||||||
|
|
||||||
def test_parse_fine_tuned_lora_name_valid():
|
def test_parse_fine_tuned_lora_name_valid():
|
||||||
fixture = {
|
fixture = [
|
||||||
("base_model.model.lm_head.lora_A.weight", "lm_head", True, False),
|
LoRANameParserTestConfig("base_model.model.lm_head.lora_A.weight",
|
||||||
("base_model.model.lm_head.lora_B.weight", "lm_head", False, False),
|
"lm_head", True, False),
|
||||||
(
|
LoRANameParserTestConfig("base_model.model.lm_head.lora_B.weight",
|
||||||
|
"lm_head", False, False),
|
||||||
|
LoRANameParserTestConfig(
|
||||||
"base_model.model.model.embed_tokens.lora_embedding_A",
|
"base_model.model.model.embed_tokens.lora_embedding_A",
|
||||||
"model.embed_tokens",
|
"model.embed_tokens",
|
||||||
True,
|
True,
|
||||||
False,
|
False,
|
||||||
),
|
),
|
||||||
(
|
LoRANameParserTestConfig(
|
||||||
"base_model.model.model.embed_tokens.lora_embedding_B",
|
"base_model.model.model.embed_tokens.lora_embedding_B",
|
||||||
"model.embed_tokens",
|
"model.embed_tokens",
|
||||||
False,
|
False,
|
||||||
False,
|
False,
|
||||||
),
|
),
|
||||||
(
|
LoRANameParserTestConfig(
|
||||||
"base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
|
"base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
|
||||||
"model.layers.9.mlp.down_proj",
|
"model.layers.9.mlp.down_proj",
|
||||||
True,
|
True,
|
||||||
False,
|
False,
|
||||||
),
|
),
|
||||||
(
|
LoRANameParserTestConfig(
|
||||||
"base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
|
"base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
|
||||||
"model.layers.9.mlp.down_proj",
|
"model.layers.9.mlp.down_proj",
|
||||||
False,
|
False,
|
||||||
False,
|
False,
|
||||||
),
|
),
|
||||||
(
|
LoRANameParserTestConfig(
|
||||||
"language_model.layers.9.mlp.down_proj.lora_A.weight",
|
"language_model.layers.9.mlp.down_proj.lora_A.weight",
|
||||||
"language_model.layers.9.mlp.down_proj",
|
"language_model.layers.9.mlp.down_proj",
|
||||||
True,
|
True,
|
||||||
False,
|
False,
|
||||||
),
|
),
|
||||||
(
|
LoRANameParserTestConfig(
|
||||||
"language_model.layers.9.mlp.down_proj.lora_B.weight",
|
"language_model.layers.9.mlp.down_proj.lora_B.weight",
|
||||||
"language_model.layers.9.mlp.down_proj",
|
"language_model.layers.9.mlp.down_proj",
|
||||||
False,
|
False,
|
||||||
False,
|
False,
|
||||||
),
|
),
|
||||||
}
|
# Test with WeightsMapper
|
||||||
for name, module_name, is_lora_a, is_bias in fixture:
|
LoRANameParserTestConfig(
|
||||||
|
"base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
|
||||||
|
"language_model.model.layers.9.mlp.down_proj",
|
||||||
|
True,
|
||||||
|
False,
|
||||||
|
weights_mapper=WeightsMapper(
|
||||||
|
orig_to_new_prefix={"model.": "language_model.model."}),
|
||||||
|
),
|
||||||
|
LoRANameParserTestConfig(
|
||||||
|
"base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
|
||||||
|
"language_model.model.layers.9.mlp.down_proj",
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
weights_mapper=WeightsMapper(
|
||||||
|
orig_to_new_prefix={"model.": "language_model.model."}),
|
||||||
|
),
|
||||||
|
LoRANameParserTestConfig(
|
||||||
|
"model.layers.9.mlp.down_proj.lora_A.weight",
|
||||||
|
"language_model.model.layers.9.mlp.down_proj",
|
||||||
|
True,
|
||||||
|
False,
|
||||||
|
weights_mapper=WeightsMapper(
|
||||||
|
orig_to_new_prefix={"model.": "language_model.model."}),
|
||||||
|
),
|
||||||
|
LoRANameParserTestConfig(
|
||||||
|
"model.layers.9.mlp.down_proj.lora_B.weight",
|
||||||
|
"language_model.model.layers.9.mlp.down_proj",
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
weights_mapper=WeightsMapper(
|
||||||
|
orig_to_new_prefix={"model.": "language_model.model."}),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
for name, module_name, is_lora_a, is_bias, weights_mapper in fixture:
|
||||||
assert (module_name, is_lora_a,
|
assert (module_name, is_lora_a,
|
||||||
is_bias) == parse_fine_tuned_lora_name(name)
|
is_bias) == parse_fine_tuned_lora_name(name, weights_mapper)
|
||||||
|
|
||||||
|
|
||||||
def test_parse_fine_tuned_lora_name_invalid():
|
def test_parse_fine_tuned_lora_name_invalid():
|
||||||
|
|||||||
@ -117,16 +117,18 @@ def parse_fine_tuned_lora_name(
|
|||||||
# LoRA weight qualified name usually starts with `base_model.model.`,
|
# LoRA weight qualified name usually starts with `base_model.model.`,
|
||||||
# so we remove the prefix `base_model.model.` to make the following
|
# so we remove the prefix `base_model.model.` to make the following
|
||||||
# mapping correctly.
|
# mapping correctly.
|
||||||
if "base_model.model." in name:
|
if name.startswith("base_model.model."):
|
||||||
name = name.replace("base_model.model.", "")
|
name = name.replace("base_model.model.", "")
|
||||||
name = weights_mapper._map_name(name) if weights_mapper else name
|
name = weights_mapper._map_name(name) if weights_mapper else name
|
||||||
# recover the prefix `base_model.model.`
|
# recover the prefix `base_model.model.`
|
||||||
name = "base_model.model." + name
|
name = "base_model.model." + name
|
||||||
|
else:
|
||||||
|
name = weights_mapper._map_name(name) if weights_mapper else name
|
||||||
|
|
||||||
# In some situations, we may not start with `base_model.model.`.
|
# In some situations, we may not start with `base_model.model.`.
|
||||||
# If we don't (e.g., ibm-granite/granite-speech-3.3-8b),
|
# If we don't (e.g., ibm-granite/granite-speech-3.3-8b),
|
||||||
# we should keep the prefix intact.
|
# we should keep the prefix intact.
|
||||||
start_index = 2 if "base_model.model." in name else 0
|
start_index = 2 if name.startswith("base_model.model.") else 0
|
||||||
|
|
||||||
parts = name.split(".")
|
parts = name.split(".")
|
||||||
if parts[-1] == "weight" and (parts[-2] == "lora_A"
|
if parts[-1] == "weight" and (parts[-2] == "lora_A"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user