[Misc][LoRA] Fix LoRA weight mapper (#11495)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Jee Jee Li 2024-12-26 15:52:48 +08:00 committed by GitHub
parent dbeac95dbb
commit aa25985bd1
5 changed files with 30 additions and 28 deletions


@@ -74,7 +74,7 @@ def test_load_checkpoints(
                 embedding_padding_modules=embed_padding_modules)
-def test_lora_weights_mapping(baichuan_lora_files, ):
+def test_lora_weights_mapping(baichuan_lora_files):
     supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
     packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
     embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
@@ -86,10 +86,14 @@ def test_lora_weights_mapping(baichuan_lora_files, ):
         else:
             expected_lora_modules.append(module)
-    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
-        "model.": "language_model.model.",
-    }, )
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            "model.": "language_model.model.",
+        },
+        orig_to_new_substr={
+            ".layers.": ".baichuan_layers.",
+        },
+    )
     lora_model = LoRAModel.from_local_checkpoint(
         baichuan_lora_files,
         expected_lora_modules,
@@ -101,3 +105,4 @@ def test_lora_weights_mapping(baichuan_lora_files, ):
     )
     for name in lora_model.loras:
         assert name.startswith(hf_to_vllm_mapper.orig_to_new_prefix["model."])
+        assert ".baichuan_layers." in name


@@ -22,7 +22,7 @@ IMAGE_ASSETS = [
 # After fine-tuning with LoRA, all generated content should start begin `A`.
 EXPECTED_OUTPUT = [
-    "A stop sign stands prominently in the foreground, with a traditional Chinese gate and a black SUV in the background, illustrating a blend of modern and cultural elements.",  # noqa: E501
+    "A red stop sign stands prominently in the foreground, with a traditional Chinese gate and a black SUV in the background, illustrating a blend of modern and cultural elements.",  # noqa: E501
     "A majestic skyscraper stands tall, partially obscured by a vibrant canopy of cherry blossoms, against a clear blue sky.",  # noqa: E501
 ]
@@ -76,3 +76,7 @@ def test_qwen2vl_lora(qwen2vl_lora_files):
     output1 = do_sample(llm, qwen2vl_lora_files, lora_id=1)
     for i in range(len(EXPECTED_OUTPUT)):
         assert EXPECTED_OUTPUT[i].startswith(output1[i])
+    output2 = do_sample(llm, qwen2vl_lora_files, lora_id=2)
+    for i in range(len(EXPECTED_OUTPUT)):
+        assert EXPECTED_OUTPUT[i].startswith(output2[i])


@@ -231,7 +231,8 @@ class LoRAModel(AdapterModel):
             with safetensors.safe_open(lora_tensor_path,
                                        framework="pt") as f:  # type: ignore
                 for lora_module in f.keys():  # noqa
-                    module_name, _, _ = parse_fine_tuned_lora_name(lora_module)
+                    module_name, _, _ = parse_fine_tuned_lora_name(
+                        lora_module, weights_mapper)
                     part_name = module_name.split(".")[-1]
                     if part_name not in expected_lora_modules:
                         unexpected_modules.append(module_name)
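For context, the loop above only compares the last component of each parsed module name against expected_lora_modules. A rough standalone sketch of that check, with the safetensors file replaced by a hard-coded key list and the parsing reduced to plain string handling (all names are illustrative, not from a real adapter):

expected_lora_modules = ["qkv_proj", "o_proj"]
checkpoint_keys = [
    "base_model.model.model.layers.0.self_attn.qkv_proj.lora_A.weight",
    "base_model.model.model.layers.0.mlp.down_proj.lora_A.weight",
]
unexpected_modules = []
for key in checkpoint_keys:
    # Stand-in for parse_fine_tuned_lora_name(key, weights_mapper).
    module_name = key.removeprefix("base_model.model.").removesuffix(".lora_A.weight")
    if module_name.split(".")[-1] not in expected_lora_modules:
        unexpected_modules.append(module_name)
assert unexpected_modules == ["model.layers.0.mlp.down_proj"]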


@@ -1,4 +1,3 @@
-import copy
 import os
 import re
 from typing import List, Optional, Set, Tuple, Type, Union
@@ -32,7 +31,6 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.models.utils import WeightsMapper
-from vllm.utils import print_warning_once
 logger = init_logger(__name__)
@@ -112,36 +110,28 @@ def parse_fine_tuned_lora_name(
             is_bias whether the tensor is lora bias.
     """
-    w_mapper = None
-    if weights_mapper:
-        w_mapper = copy.deepcopy(weights_mapper)
-        # TODO: Currently only supports mapping for prefix, mapping for
-        # substr and subfix will be supported in the future.
-        for attr, mapping in [
-            ("orig_to_new_substr", w_mapper.orig_to_new_substr),
-            ("orig_to_new_suffix", w_mapper.orig_to_new_suffix),
-        ]:
-            if mapping:
-                print_warning_once(
-                    f"vLLM currently does not support mapping of LoRA weights "
-                    f"for {mapping}.")
-                setattr(w_mapper, attr, {})
-    mapper = (lambda name: w_mapper._map_name(name)
-              if w_mapper is not None else name)
+    # LoRA weight qualified name always starts with `base_model.model.`,
+    # so we remove the prefix `base_model.model.` to make the following
+    # mapping correctly.
+    if "base_model.model." in name:
+        name = name.replace("base_model.model.", "")
+        name = weights_mapper._map_name(name) if weights_mapper else name
+        # recover the prefix `base_model.model.`
+        name = "base_model.model." + name
     parts = name.split(".")
     if parts[-1] == "weight" and (parts[-2] == "lora_A"
                                   or parts[-2] == "lora_B"):
         new_name = ".".join(parts[2:-2])
-        return mapper(new_name), parts[-2] == "lora_A", False
+        return new_name, parts[-2] == "lora_A", False
     if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
         new_name = ".".join(parts[2:-1])
-        return mapper(new_name), parts[-1] == "lora_embedding_A", False
+        return new_name, parts[-1] == "lora_embedding_A", False
     if parts[-1] == "bias":
         new_name = ".".join(parts[2:-2])
-        return mapper(new_name), False, True
+        return new_name, False, True
     raise ValueError(f"{name} is unsupported LoRA weight")


@@ -91,6 +91,8 @@ class WorkerLoRAManager(AbstractWorkerManager):
                         packed_modules_mapping[module])
                 else:
                     expected_lora_modules.append(module)
+            expected_lora_modules = list(set(expected_lora_modules))
             lora_path = get_adapter_absolute_path(lora_request.lora_path)
             # For some models like Qwen2VL, we need to use hf_to_vllm_mapper