Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Misc] Avoid loading incorrect LoRA config (#3777)
parent 6c0b04515f
commit 11dd6ebb89
tests/lora/test_lora_checkpoints.py (new file, +40 lines)
@@ -0,0 +1,40 @@
+import pytest
+
+from vllm.lora.models import LoRAModel
+from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
+
+
+@pytest.mark.parametrize("lora_name", ["baichuan7B", "chatglm3-6b"])
+def test_load_checkpoints(lora_name, chatglm3_lora_files, baichuan_lora_files):
+    supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
+    packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
+    embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
+    embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
+    expected_lora_modules = []
+    for module in supported_lora_modules:
+        if module in packed_modules_mapping:
+            expected_lora_modules.extend(packed_modules_mapping[module])
+        else:
+            expected_lora_modules.append(module)
+    if lora_name == "baichuan7B":
+        # For the baichuan7B model, load its LoRA,
+        # and the test should pass.
+        LoRAModel.from_local_checkpoint(
+            baichuan_lora_files,
+            expected_lora_modules,
+            lora_model_id=1,
+            device="cpu",
+            embedding_modules=embedding_modules,
+            embedding_padding_modules=embed_padding_modules)
+    else:
+        # For the baichuan7B model, load chatglm3-6b's LoRA,
+        # and the test should raise the following error.
+        expected_error = "Please verify that the loaded LoRA module is correct"  # noqa: E501
+        with pytest.raises(ValueError, match=expected_error):
+            LoRAModel.from_local_checkpoint(
+                chatglm3_lora_files,
+                expected_lora_modules,
+                lora_model_id=1,
+                device="cpu",
+                embedding_modules=embedding_modules,
+                embedding_padding_modules=embed_padding_modules)
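Note: the test builds expected_lora_modules by expanding the base model's packed modules. Below is a minimal, self-contained sketch of that expansion; the concrete module names are illustrative assumptions, not values read from vLLM.

# Sketch only: shows how packed modules fan out into the sub-module names
# that a PEFT adapter_config.json would list under "target_modules".
supported_lora_modules = ["W_pack", "o_proj", "gate_up_proj", "down_proj"]
packed_modules_mapping = {
    "W_pack": ["W_pack"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

expected_lora_modules = []
for module in supported_lora_modules:
    if module in packed_modules_mapping:
        # A packed module expands to its constituent projection names.
        expected_lora_modules.extend(packed_modules_mapping[module])
    else:
        expected_lora_modules.append(module)

print(expected_lora_modules)
# ['W_pack', 'o_proj', 'gate_proj', 'up_proj', 'down_proj']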
vllm/lora/models.py
@@ -191,6 +191,7 @@ class LoRAModel:
     def from_local_checkpoint(
             cls,
             lora_dir: str,
+            expected_lora_modules: List[str],
             lora_model_id: Optional[int] = None,
             device: str = "cuda",
             dtype: Optional[torch.dtype] = None,
@@ -206,6 +207,20 @@ class LoRAModel:
             lora_dir, "new_embeddings.safetensors")
         new_embeddings_bin_file_path = os.path.join(lora_dir,
                                                     "new_embeddings.bin")
+        with open(lora_config_path) as f:
+            config = json.load(f)
+        target_modules = config["target_modules"]
+        unexpected_modules = []
+        for module in target_modules:
+            if module not in expected_lora_modules:
+                unexpected_modules.append(module)
+        # loaded lora's target modules must be a subset of expected_lora_modules
+        if unexpected_modules:
+            raise ValueError(
+                f"While loading {lora_dir}, expected"
+                f" target modules in {expected_lora_modules}"
+                f" but received {unexpected_modules}."
+                f" Please verify that the loaded LoRA module is correct")
         if os.path.isfile(lora_tensor_path):
             tensors = safetensors.torch.load_file(lora_tensor_path)
         elif os.path.isfile(lora_bin_file_path):
@@ -220,8 +235,6 @@ class LoRAModel:
         elif os.path.isfile(new_embeddings_bin_file_path):
             embeddings = torch.load(new_embeddings_bin_file_path)
 
-        with open(lora_config_path) as f:
-            config = json.load(f)
         rank = config["r"]
         lora_alpha = config["lora_alpha"]
         return cls.from_lora_tensors(
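Note: the check added to from_local_checkpoint reads the adapter's target_modules and requires them to be a subset of what the base model can apply LoRA to. A standalone sketch of the same logic, assuming only the PEFT-style adapter_config.json layout (the helper name is hypothetical):

import json
import os


def check_lora_target_modules(lora_dir, expected_lora_modules):
    """Sketch of the subset check: every target module declared by the
    adapter must be one the base model supports for LoRA."""
    with open(os.path.join(lora_dir, "adapter_config.json")) as f:
        config = json.load(f)
    unexpected = [m for m in config["target_modules"]
                  if m not in expected_lora_modules]
    if unexpected:
        raise ValueError(
            f"While loading {lora_dir}, expected target modules in "
            f"{expected_lora_modules} but received {unexpected}. "
            f"Please verify that the loaded LoRA module is correct")

Loading adapter_config.json up front is also why the second `with open(lora_config_path)` block is removed in the last hunk above: `config` is already in scope when `r` and `lora_alpha` are read.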
vllm/lora/worker_manager.py
@@ -136,8 +136,19 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
 
     def _load_lora(self, lora_request: LoRARequest) -> LoRAModel:
         try:
+            model = self._lora_manager.model
+            supported_lora_modules = model.supported_lora_modules
+            packed_modules_mapping = model.packed_modules_mapping
+            expected_lora_modules = []
+            for module in supported_lora_modules:
+                if module in packed_modules_mapping:
+                    expected_lora_modules.extend(
+                        packed_modules_mapping[module])
+                else:
+                    expected_lora_modules.append(module)
             lora = self._lora_model_cls.from_local_checkpoint(
                 lora_request.lora_local_path,
+                expected_lora_modules,
                 lora_model_id=lora_request.lora_int_id,
                 device="cpu",
                 dtype=self.lora_config.lora_dtype,
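Note: with the worker-side change, an adapter whose target_modules do not match the serving model now fails while the worker loads it, instead of being applied to the wrong layers. A hypothetical end-to-end sketch; the model name, adapter path, and the surrounding error handling are assumptions, since the worker's try block may re-wrap the error before it reaches the caller:

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Assumed setup: a Baichuan base model with LoRA enabled.
llm = LLM(model="baichuan-inc/Baichuan-7B", enable_lora=True,
          trust_remote_code=True)

try:
    llm.generate(
        "Hello",
        SamplingParams(max_tokens=16),
        # Assumed local path to an adapter trained for a different base model.
        lora_request=LoRARequest("mismatched-lora", 1,
                                 "/path/to/chatglm3-lora"),
    )
except Exception as err:
    # The ValueError raised by the new config check surfaces here,
    # possibly wrapped by the worker's LoRA loading error handling.
    print(err)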