mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-19 18:07:00 +08:00
[Misc] Improve error message when LoRA parsing fails (#5194)
This commit is contained in:
parent
c81da5f56d
commit
0bfa1c4f13
@ -1,12 +1,13 @@
|
|||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
import pytest
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule
|
from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule
|
||||||
from vllm.utils import LRUCache
|
from vllm.utils import LRUCache
|
||||||
|
|
||||||
|
|
||||||
def test_parse_fine_tuned_lora_name():
|
def test_parse_fine_tuned_lora_name_valid():
|
||||||
fixture = {
|
fixture = {
|
||||||
("base_model.model.lm_head.lora_A.weight", "lm_head", True),
|
("base_model.model.lm_head.lora_A.weight", "lm_head", True),
|
||||||
("base_model.model.lm_head.lora_B.weight", "lm_head", False),
|
("base_model.model.lm_head.lora_B.weight", "lm_head", False),
|
||||||
@ -35,6 +36,17 @@ def test_parse_fine_tuned_lora_name():
|
|||||||
assert (module_name, is_lora_a) == parse_fine_tuned_lora_name(name)
|
assert (module_name, is_lora_a) == parse_fine_tuned_lora_name(name)
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_fine_tuned_lora_name_invalid():
|
||||||
|
fixture = {
|
||||||
|
"weight",
|
||||||
|
"base_model.weight",
|
||||||
|
"base_model.model.weight",
|
||||||
|
}
|
||||||
|
for name in fixture:
|
||||||
|
with pytest.raises(ValueError, match="unsupported LoRA weight"):
|
||||||
|
parse_fine_tuned_lora_name(name)
|
||||||
|
|
||||||
|
|
||||||
def test_replace_submodule():
|
def test_replace_submodule():
|
||||||
model = nn.Sequential(
|
model = nn.Sequential(
|
||||||
OrderedDict([
|
OrderedDict([
|
||||||
|
|||||||
@ -94,13 +94,12 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
|
|||||||
is_lora_a whether the tensor is lora_a or lora_b.
|
is_lora_a whether the tensor is lora_a or lora_b.
|
||||||
"""
|
"""
|
||||||
parts = name.split(".")
|
parts = name.split(".")
|
||||||
assert parts[0] == "base_model"
|
|
||||||
assert parts[1] == "model"
|
|
||||||
if parts[-1] == "weight":
|
|
||||||
assert parts[-2] == "lora_A" or parts[-2] == "lora_B"
|
|
||||||
return ".".join(parts[2:-2]), parts[-2] == "lora_A"
|
|
||||||
|
|
||||||
if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
|
if len(parts) >= 2 and parts[0] == "base_model" and parts[1] == "model":
|
||||||
return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A"
|
if parts[-1] == "weight":
|
||||||
|
if parts[-2] == "lora_A" or parts[-2] == "lora_B":
|
||||||
|
return ".".join(parts[2:-2]), parts[-2] == "lora_A"
|
||||||
|
elif parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
|
||||||
|
return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A"
|
||||||
|
|
||||||
raise ValueError(f"{name} is unsupported format")
|
raise ValueError(f"{name} is unsupported LoRA weight")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user