mirror of https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-13 23:55:44 +08:00

[Misc][LoRA] Fix LoRA weight mapper (#11495)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>

parent dbeac95dbb
commit aa25985bd1
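Summary of the change, as read from the diff below: `parse_fine_tuned_lora_name` used to deep-copy the `WeightsMapper`, warn that `orig_to_new_substr` / `orig_to_new_suffix` mappings were unsupported, and blank them out. It now strips the `base_model.model.` prefix, runs the full mapper over the inner name, and restores the prefix, so substring mappings (e.g. `.layers.` → `.baichuan_layers.` in the Baichuan test) are honored. `LoRAModel.from_local_checkpoint` passes the mapper through before validating tensor names against `expected_lora_modules`, and `WorkerLoRAManager` deduplicates that list.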
tests/lora/test_lora_checkpoints.py

@@ -74,7 +74,7 @@ def test_load_checkpoints(
             embedding_padding_modules=embed_padding_modules)
 
 
-def test_lora_weights_mapping(baichuan_lora_files, ):
+def test_lora_weights_mapping(baichuan_lora_files):
     supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
     packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
     embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
@@ -86,10 +86,14 @@ def test_lora_weights_mapping(baichuan_lora_files, ):
         else:
             expected_lora_modules.append(module)
 
-    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
-        "model.": "language_model.model.",
-    }, )
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            "model.": "language_model.model.",
+        },
+        orig_to_new_substr={
+            ".layers.": ".baichuan_layers.",
+        },
+    )
     lora_model = LoRAModel.from_local_checkpoint(
         baichuan_lora_files,
         expected_lora_modules,
@@ -101,3 +105,4 @@ def test_lora_weights_mapping(baichuan_lora_files, ):
     )
     for name in lora_model.loras:
         assert name.startswith(hf_to_vllm_mapper.orig_to_new_prefix["model."])
+        assert ".baichuan_layers." in name
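To make those assertions concrete, here is a hypothetical walk-through of one checkpoint key under the test's mapper. Only the two mapping dicts come from the test; the tensor key itself is invented for illustration:

# Hypothetical walk-through (tensor key invented; mapper dicts from the test):
key = "base_model.model.model.layers.0.self_attn.W_pack.lora_A.weight"
# 1. strip "base_model.model."       -> "model.layers.0.self_attn.W_pack.lora_A.weight"
# 2. apply ".layers." -> ".baichuan_layers." and "model." -> "language_model.model."
#                                     -> "language_model.model.baichuan_layers.0.self_attn.W_pack.lora_A.weight"
# 3. drop the "lora_A.weight" tail   -> the module name stored in lora_model.loras
name = "language_model.model.baichuan_layers.0.self_attn.W_pack"
assert name.startswith("language_model.model.")
assert ".baichuan_layers." in name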
tests/lora/test_qwen2vl.py

@@ -22,7 +22,7 @@ IMAGE_ASSETS = [
 
 # After fine-tuning with LoRA, all generated content should start begin `A`.
 EXPECTED_OUTPUT = [
-    "A stop sign stands prominently in the foreground, with a traditional Chinese gate and a black SUV in the background, illustrating a blend of modern and cultural elements.",  # noqa: E501
+    "A red stop sign stands prominently in the foreground, with a traditional Chinese gate and a black SUV in the background, illustrating a blend of modern and cultural elements.",  # noqa: E501
     "A majestic skyscraper stands tall, partially obscured by a vibrant canopy of cherry blossoms, against a clear blue sky.",  # noqa: E501
 ]
 
@@ -76,3 +76,7 @@ def test_qwen2vl_lora(qwen2vl_lora_files):
     output1 = do_sample(llm, qwen2vl_lora_files, lora_id=1)
     for i in range(len(EXPECTED_OUTPUT)):
         assert EXPECTED_OUTPUT[i].startswith(output1[i])
+
+    output2 = do_sample(llm, qwen2vl_lora_files, lora_id=2)
+    for i in range(len(EXPECTED_OUTPUT)):
+        assert EXPECTED_OUTPUT[i].startswith(output2[i])
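The test's do_sample helper is not part of this hunk. For orientation, a simplified, text-only sketch of what such a helper typically does with vllm's public API; the prompt and sampling values are placeholders, not the test's (the real test feeds multimodal prompts):

# Simplified sketch of a do_sample-style helper, assuming vllm's public
# LLM.generate + LoRARequest API; prompt and params are placeholders.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest


def do_sample(llm: LLM, lora_path: str, lora_id: int) -> list[str]:
    prompts = ["Describe the image."]  # placeholder prompt
    sampling_params = SamplingParams(temperature=0, max_tokens=64)
    outputs = llm.generate(
        prompts,
        sampling_params,
        # The same adapter files can be registered under different integer
        # IDs; the test runs IDs 1 and 2 to exercise both load paths.
        lora_request=LoRARequest(f"lora-{lora_id}", lora_id, lora_path),
    )
    return [o.outputs[0].text.strip() for o in outputs]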
vllm/lora/models.py

@@ -231,7 +231,8 @@ class LoRAModel(AdapterModel):
             with safetensors.safe_open(lora_tensor_path,
                                        framework="pt") as f:  # type: ignore
                 for lora_module in f.keys():  # noqa
-                    module_name, _, _ = parse_fine_tuned_lora_name(lora_module)
+                    module_name, _, _ = parse_fine_tuned_lora_name(
+                        lora_module, weights_mapper)
                     part_name = module_name.split(".")[-1]
                     if part_name not in expected_lora_modules:
                         unexpected_modules.append(module_name)
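For context, a hypothetical tensor key and what the now-mapper-aware call would return; the key is invented and the mapper is the one from the Baichuan test above:

# Illustrative only: key invented, mapper as in the Baichuan test above.
# key = "base_model.model.model.layers.0.self_attn.W_pack.lora_A.weight"
# parse_fine_tuned_lora_name(key, weights_mapper) would yield:
#   module_name = "language_model.model.baichuan_layers.0.self_attn.W_pack"
#   is_lora_a   = True,  is_bias = False
# part_name = module_name.split(".")[-1]  # -> "W_pack", checked against
#                                         #    expected_lora_modules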
vllm/lora/utils.py

@@ -1,4 +1,3 @@
-import copy
 import os
 import re
 from typing import List, Optional, Set, Tuple, Type, Union

@@ -32,7 +31,6 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.models.utils import WeightsMapper
-from vllm.utils import print_warning_once
 
 logger = init_logger(__name__)
 
@@ -112,36 +110,28 @@ def parse_fine_tuned_lora_name(
         is_bias whether the tensor is lora bias.
     """
 
-    w_mapper = None
-    if weights_mapper:
-        w_mapper = copy.deepcopy(weights_mapper)
-        # TODO: Currently only supports mapping for prefix, mapping for
-        # substr and subfix will be supported in the future.
-        for attr, mapping in [
-            ("orig_to_new_substr", w_mapper.orig_to_new_substr),
-            ("orig_to_new_suffix", w_mapper.orig_to_new_suffix),
-        ]:
-            if mapping:
-                print_warning_once(
-                    f"vLLM currently does not support mapping of LoRA weights "
-                    f"for {mapping}.")
-                setattr(w_mapper, attr, {})
-
-    mapper = (lambda name: w_mapper._map_name(name)
-              if w_mapper is not None else name)
-
+    # LoRA weight qualified name always starts with `base_model.model.`,
+    # so we remove the prefix `base_model.model.` to make the following
+    # mapping correctly.
+    if "base_model.model." in name:
+        name = name.replace("base_model.model.", "")
+        name = weights_mapper._map_name(name) if weights_mapper else name
+        # recover the prefix `base_model.model.`
+        name = "base_model.model." + name
+
     parts = name.split(".")
     if parts[-1] == "weight" and (parts[-2] == "lora_A"
                                   or parts[-2] == "lora_B"):
         new_name = ".".join(parts[2:-2])
-        return mapper(new_name), parts[-2] == "lora_A", False
+        return new_name, parts[-2] == "lora_A", False
 
     if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
         new_name = ".".join(parts[2:-1])
-        return mapper(new_name), parts[-1] == "lora_embedding_A", False
+        return new_name, parts[-1] == "lora_embedding_A", False
 
     if parts[-1] == "bias":
         new_name = ".".join(parts[2:-2])
-        return mapper(new_name), False, True
+        return new_name, False, True
 
     raise ValueError(f"{name} is unsupported LoRA weight")
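The heart of the fix is the prefix round-trip: strip `base_model.model.`, run the full mapper (prefix and substring rules alike) over the inner name, then re-attach the prefix before splitting. A minimal self-contained sketch of that round-trip, with a toy stand-in for WeightsMapper; the class and function names here are illustrative, not vllm's API:

# Minimal sketch of the renaming round-trip in parse_fine_tuned_lora_name,
# assuming a toy mapper with WeightsMapper-like prefix/substr semantics.
from typing import Dict, Optional, Tuple


class ToyMapper:
    """Illustrative stand-in for vllm's WeightsMapper (not the real class)."""

    def __init__(self, prefix: Dict[str, str], substr: Dict[str, str]):
        self.prefix = prefix
        self.substr = substr

    def map_name(self, name: str) -> str:
        for old, new in self.substr.items():  # substring replacements
            name = name.replace(old, new)
        for old, new in self.prefix.items():  # prefix replacements
            if name.startswith(old):
                name = new + name[len(old):]
        return name


def parse_lora_name(name: str,
                    mapper: Optional[ToyMapper]) -> Tuple[str, bool, bool]:
    # LoRA tensor names start with "base_model.model.". Strip it so the
    # mapper sees the same prefixes it would see on base-model weights,
    # then put it back before splitting.
    if "base_model.model." in name:
        name = name.replace("base_model.model.", "")
        name = mapper.map_name(name) if mapper else name
        name = "base_model.model." + name

    parts = name.split(".")
    if parts[-1] == "weight" and parts[-2] in ("lora_A", "lora_B"):
        # parts[2:-2] drops "base_model.model" and "lora_X.weight".
        return ".".join(parts[2:-2]), parts[-2] == "lora_A", False
    if parts[-1] == "bias":
        return ".".join(parts[2:-2]), False, True
    raise ValueError(f"{name} is unsupported LoRA weight")


mapper = ToyMapper(prefix={"model.": "language_model.model."},
                   substr={".layers.": ".baichuan_layers."})
module, is_a, is_bias = parse_lora_name(
    "base_model.model.model.layers.0.self_attn.W_pack.lora_A.weight", mapper)
assert module == "language_model.model.baichuan_layers.0.self_attn.W_pack"
assert is_a and not is_bias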
vllm/lora/worker_manager.py

@@ -91,6 +91,8 @@ class WorkerLoRAManager(AbstractWorkerManager):
                     packed_modules_mapping[module])
             else:
                 expected_lora_modules.append(module)
+
+        expected_lora_modules = list(set(expected_lora_modules))
         lora_path = get_adapter_absolute_path(lora_request.lora_path)
 
         # For some models like Qwen2VL, we need to use hf_to_vllm_mapper
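Why the added `list(set(...))`: expanding `packed_modules_mapping` can append the same sub-module name more than once, and deduplicating keeps the expected-module check in `from_local_checkpoint` from scanning redundant entries. A toy illustration with hypothetical module names:

# Toy illustration (module names hypothetical): packed-module expansion can
# introduce duplicates, which list(set(...)) removes. Order is not preserved,
# but the downstream membership check does not depend on order.
expected_lora_modules = ["W_pack", "o_proj", "gate_up_proj", "o_proj"]
expected_lora_modules = list(set(expected_lora_modules))
assert sorted(expected_lora_modules) == ["W_pack", "gate_up_proj", "o_proj"]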