mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-23 18:44:28 +08:00
[Bugfix] Fix MoE LoRA bin/pt loading (#31161)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
73cfb7a722
commit
27c6c2f98c
@ -12,7 +12,6 @@ from vllm.lora.peft_helper import PEFTHelper
|
|||||||
from vllm.lora.utils import (
|
from vllm.lora.utils import (
|
||||||
get_lora_id,
|
get_lora_id,
|
||||||
is_base_embeddding_weights,
|
is_base_embeddding_weights,
|
||||||
is_regex_target_modules,
|
|
||||||
parse_fine_tuned_lora_name,
|
parse_fine_tuned_lora_name,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||||
@ -201,37 +200,13 @@ class LoRAModel:
|
|||||||
for module in f.keys(): # noqa
|
for module in f.keys(): # noqa
|
||||||
tensors[module] = f.get_tensor(module)
|
tensors[module] = f.get_tensor(module)
|
||||||
elif os.path.isfile(lora_bin_file_path) or os.path.isfile(lora_pt_file_path):
|
elif os.path.isfile(lora_bin_file_path) or os.path.isfile(lora_pt_file_path):
|
||||||
# When a bin/pt file is provided, we rely on config to find
|
|
||||||
# unexpected modules.
|
|
||||||
unexpected_modules = []
|
|
||||||
target_modules = peft_helper.target_modules
|
|
||||||
if not isinstance(target_modules, list):
|
|
||||||
target_modules = [target_modules]
|
|
||||||
for module in target_modules:
|
|
||||||
# Compatible with more modules,
|
|
||||||
# such as:layers.11.self_attn.k_proj
|
|
||||||
part_name = module.split(".")[-1]
|
|
||||||
if part_name not in expected_lora_modules:
|
|
||||||
unexpected_modules.append(module)
|
|
||||||
# loaded lora's target modules must be a subset of
|
|
||||||
# expected_lora_modules. It is not reliable. See
|
|
||||||
# https://github.com/vllm-project/vllm/pull/5909. But there's no
|
|
||||||
# other better mechanism.
|
|
||||||
if unexpected_modules and not is_regex_target_modules(
|
|
||||||
peft_helper.target_modules, expected_lora_modules
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
f"While loading {lora_dir}, expected"
|
|
||||||
f" target modules in {expected_lora_modules}"
|
|
||||||
f" but received {unexpected_modules}."
|
|
||||||
f" Please verify that the loaded LoRA module is correct"
|
|
||||||
)
|
|
||||||
lora_file_path = (
|
lora_file_path = (
|
||||||
lora_bin_file_path
|
lora_bin_file_path
|
||||||
if os.path.isfile(lora_bin_file_path)
|
if os.path.isfile(lora_bin_file_path)
|
||||||
else lora_pt_file_path
|
else lora_pt_file_path
|
||||||
)
|
)
|
||||||
tensors = torch.load(lora_file_path, map_location=device, weights_only=True)
|
tensors = torch.load(lora_file_path, map_location=device, weights_only=True)
|
||||||
|
check_unexpected_modules(tensors)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"{lora_dir} doesn't contain tensors")
|
raise ValueError(f"{lora_dir} doesn't contain tensors")
|
||||||
|
|
||||||
|
|||||||
@ -5,7 +5,6 @@ import os
|
|||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
import huggingface_hub
|
import huggingface_hub
|
||||||
import regex as re
|
|
||||||
from huggingface_hub.utils import (
|
from huggingface_hub.utils import (
|
||||||
EntryNotFoundError,
|
EntryNotFoundError,
|
||||||
HfHubHTTPError,
|
HfHubHTTPError,
|
||||||
@ -186,39 +185,6 @@ def is_base_embeddding_weights(name: str) -> bool:
|
|||||||
return name.endswith(embedding_suffixes)
|
return name.endswith(embedding_suffixes)
|
||||||
|
|
||||||
|
|
||||||
def is_regex_target_modules(
|
|
||||||
load_modules: str | list[str], expected_lora_modules: set[str]
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
PEFT supports passing `target_modules` in the form of regular expressions,
|
|
||||||
such as `model.*(q_proj|k_proj|v_proj)$`. This function is mainly used to
|
|
||||||
determine whether the suffix in the regular expression is present in the
|
|
||||||
`expected_lora_modules`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def is_valid_regex(pattern):
|
|
||||||
try:
|
|
||||||
re.compile(pattern)
|
|
||||||
return True
|
|
||||||
except re.error:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def is_subset(sub_list, full_set):
|
|
||||||
return set(sub_list).issubset(full_set)
|
|
||||||
|
|
||||||
# Similar to PEFT's processing logic, regex-related operations are only
|
|
||||||
# executed when the load_modules is a `str`.
|
|
||||||
if not isinstance(load_modules, str):
|
|
||||||
return False
|
|
||||||
|
|
||||||
if is_valid_regex(load_modules):
|
|
||||||
match = re.search(r"\((.*?)\)\$?$", load_modules)
|
|
||||||
if match:
|
|
||||||
suffix = match.group(1).split("|")
|
|
||||||
return is_subset(suffix, expected_lora_modules)
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def get_supported_lora_modules(model: nn.Module) -> list[str]:
|
def get_supported_lora_modules(model: nn.Module) -> list[str]:
|
||||||
"""
|
"""
|
||||||
In vLLM, all linear layers support LoRA.
|
In vLLM, all linear layers support LoRA.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user