[Bugfix] Fix Lora Name Parsing (#17196)
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
parent 18445edd0f
commit 756848e79e
@@ -39,6 +39,18 @@ def test_parse_fine_tuned_lora_name_valid():
             False,
             False,
         ),
+        (
+            "language_model.layers.9.mlp.down_proj.lora_A.weight",
+            "language_model.layers.9.mlp.down_proj",
+            True,
+            False,
+        ),
+        (
+            "language_model.layers.9.mlp.down_proj.lora_B.weight",
+            "language_model.layers.9.mlp.down_proj",
+            False,
+            False,
+        ),
     }
     for name, module_name, is_lora_a, is_bias in fixture:
         assert (module_name, is_lora_a,
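Each fixture tuple is (weight name, expected module name, expected is_lora_a, expected is_bias). As a quick sanity check of the new entries, here is a minimal sketch assuming parse_fine_tuned_lora_name is importable from vllm.lora.utils (the import path is not shown in this diff):

# Minimal check of one new fixture entry; the import path is an
# assumption based on the function name, not shown in this diff.
from vllm.lora.utils import parse_fine_tuned_lora_name

# A LoRA weight name without the usual `base_model.model.` prefix,
# as produced by e.g. ibm-granite/granite-speech-3.3-8b adapters.
name = "language_model.layers.9.mlp.down_proj.lora_A.weight"
module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name(name)
assert module_name == "language_model.layers.9.mlp.down_proj"
assert is_lora_a is True and is_bias is False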
@@ -114,7 +114,7 @@ def parse_fine_tuned_lora_name(
         is_bias whether the tensor is lora bias.
     """

-    # LoRA weight qualified name always starts with `base_model.model.`,
+    # LoRA weight qualified name usually starts with `base_model.model.`,
     # so we remove the prefix `base_model.model.` to make the following
     # mapping correctly.
     if "base_model.model." in name:
@@ -123,18 +123,23 @@ def parse_fine_tuned_lora_name(
         # recover the prefix `base_model.model.`
         name = "base_model.model." + name

+    # In some situations, we may not start with `base_model.model.`.
+    # If we don't (e.g., ibm-granite/granite-speech-3.3-8b),
+    # we should keep the prefix intact.
+    start_index = 2 if "base_model.model." in name else 0
+
     parts = name.split(".")
     if parts[-1] == "weight" and (parts[-2] == "lora_A"
                                   or parts[-2] == "lora_B"):
-        new_name = ".".join(parts[2:-2])
+        new_name = ".".join(parts[start_index:-2])
         return new_name, parts[-2] == "lora_A", False

     if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
-        new_name = ".".join(parts[2:-1])
+        new_name = ".".join(parts[start_index:-1])
         return new_name, parts[-1] == "lora_embedding_A", False

     if parts[-1] == "bias":
-        new_name = ".".join(parts[2:-2])
+        new_name = ".".join(parts[start_index:-2])
         return new_name, False, True

     raise ValueError(f"{name} is unsupported LoRA weight")
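The core of the fix is the start_index computation: names carrying the base_model.model. prefix skip their first two dotted components as before, while unprefixed names (e.g. from ibm-granite/granite-speech-3.3-8b) keep the full module path. A minimal standalone sketch of the patched logic (simplified; the name-remapping step that the comments above refer to is omitted):

# Simplified sketch of the patched parsing logic; not the full vLLM
# implementation.
def parse_lora_name_sketch(name: str) -> tuple[str, bool, bool]:
    # Skip the `base_model.model.` prefix only when it is present.
    start_index = 2 if "base_model.model." in name else 0
    parts = name.split(".")
    if parts[-1] == "weight" and parts[-2] in ("lora_A", "lora_B"):
        return ".".join(parts[start_index:-2]), parts[-2] == "lora_A", False
    if parts[-1] in ("lora_embedding_A", "lora_embedding_B"):
        return (".".join(parts[start_index:-1]),
                parts[-1] == "lora_embedding_A", False)
    if parts[-1] == "bias":
        return ".".join(parts[start_index:-2]), False, True
    raise ValueError(f"{name} is unsupported LoRA weight")

# Before this change, the hard-coded parts[2:-2] slice dropped
# "language_model" and "layers" from unprefixed names; now they survive:
print(parse_lora_name_sketch(
    "language_model.layers.9.mlp.down_proj.lora_A.weight"))
# ('language_model.layers.9.mlp.down_proj', True, False)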