[Lora] Use safetensor keys instead of adapter_config.json to find unexpected modules. (#5909)
Co-authored-by: sang <sangcho@anyscale.com>
parent c6c240aa0a
commit f5e73c9f1b
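In short: the loader now treats the tensor names stored in the adapter's safetensors file, rather than the target_modules list in adapter_config.json, as the source of truth, because peft silently drops target modules that do not exist in the base model. A minimal standalone sketch of the idea follows; the key format and the module_name_of helper are illustrative assumptions, standing in for vLLM's own parse_fine_tuned_lora_name.

import safetensors


def module_name_of(key: str) -> str:
    # A peft-style key usually looks like
    # "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight";
    # strip the trailing ".lora_A.weight" / ".lora_B.weight" suffix.
    parts = key.split(".")
    for marker in ("lora_A", "lora_B"):
        if marker in parts:
            return ".".join(parts[:parts.index(marker)])
    return ".".join(parts[:-1])


def find_unexpected_modules(lora_tensor_path, expected_lora_modules):
    # Use the safetensors keys, not adapter_config.json, as the source
    # of truth for which modules were actually LoRA-trained.
    unexpected = []
    with safetensors.safe_open(lora_tensor_path, framework="pt") as f:
        for key in f.keys():
            module_name = module_name_of(key)
            # Compare only the final path component, e.g. "q_proj".
            if module_name.split(".")[-1] not in expected_lora_modules:
                unexpected.append(module_name)
    return unexpected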
@@ -226,3 +226,4 @@ steps:
 - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.5/flashinfer-0.0.5+cu121torch2.3-cp310-cp310-linux_x86_64.whl
 - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
 - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=meta-llama/Meta-Llama-3-8B DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
+- pytest -v -s -x lora/test_mixtral.py
@@ -165,7 +165,9 @@ def sql_lora_files():

 @pytest.fixture(scope="session")
 def mixtral_lora_files():
-    return snapshot_download(repo_id="terrysun/mixtral-lora-adapter")
+    # Note: this module has incorrect adapter_config.json to test
+    # https://github.com/vllm-project/vllm/pull/5909/files.
+    return snapshot_download(repo_id="SangBinCho/mixtral-lora")


 @pytest.fixture(scope="session")
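The replacement adapter is chosen precisely because its adapter_config.json disagrees with its safetensors contents. A hypothetical illustration of that kind of mismatch (these literals are invented for clarity and are not the actual contents of SangBinCho/mixtral-lora):

# Invented example of a config/checkpoint mismatch.
adapter_config = {
    # "nonexistent_proj" is listed as a target module...
    "target_modules": ["q_proj", "v_proj", "nonexistent_proj"],
}
safetensor_keys = [
    # ...but no tensor key ever mentions it, because peft silently
    # skipped the module when the base model had nothing to wrap.
    "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight",
    "base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight",
    "base_model.model.model.layers.0.self_attn.v_proj.lora_A.weight",
    "base_model.model.model.layers.0.self_attn.v_proj.lora_B.weight",
]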
@@ -40,14 +40,14 @@ def test_mixtral_lora(mixtral_lora_files, tp_size):
                    enable_lora=True,
                    max_num_seqs=16,
                    max_loras=4,
+                   distributed_executor_backend="ray",
                    tensor_parallel_size=tp_size)

     expected_lora_output = [
         "give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])",  # noqa: E501
-        "give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])",  # noqa: E501
+        "give_opinion(name[SpellForce 3], developer[Grimlore Games], release_year[2017], rating[poor])",  # noqa: E501
         "inform(name[BioShock], release_year[2007], rating[good], genres[action-adventure, role-playing, shooter], platforms[PlayStation, Xbox, PC], available_on_steam[yes], has_linux_release[no], has_mac_release[yes])",  # noqa: E501
     ]

     assert do_sample(llm, mixtral_lora_files,
                      lora_id=1) == expected_lora_output
     assert do_sample(llm, mixtral_lora_files,
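For context, do_sample is a helper defined earlier in tests/lora/test_mixtral.py and is not part of this diff. A rough sketch of what such a helper does, assuming vLLM's public LoRARequest API (the prompts and sampling settings here are placeholders, not the test's real values):

from typing import List

import vllm
from vllm.lora.request import LoRARequest


def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
    # Placeholder prompts; the real test builds Mixtral-specific ones.
    prompts = ["..."]
    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256)
    # Attach the LoRA adapter by id and local path for this generate call.
    outputs = llm.generate(prompts,
                           sampling_params,
                           lora_request=LoRARequest("lora", lora_id,
                                                    lora_path))
    return [output.outputs[0].text.strip() for output in outputs]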
@@ -303,15 +303,47 @@ class LoRAModel:
             "new_embeddings.bin")
         with open(lora_config_path) as f:
             config = json.load(f)
-        target_modules = config["target_modules"]
+        if os.path.isfile(lora_tensor_path):
+            tensors: Dict[str, torch.Tensor] = {}
+            # Find unexpected modules.
+            # Use safetensor key as a source of truth to find expected modules.
+            # in peft if you have target_modules A, B, C and C does not exist
+            # in the model it won’t error and model will be trained with A, B
+            # loraified. C won’t exist in the safetensor but it will exist in
+            # the target_modules of the adapter_config.json.
             unexpected_modules = []
+            with safetensors.safe_open(lora_tensor_path,
+                                       framework="pt") as f:  # type: ignore
+                for lora_module in f.keys():  # noqa
+                    module_name, _ = parse_fine_tuned_lora_name(lora_module)
+                    part_name = module_name.split(".")[-1]
+                    if part_name not in expected_lora_modules:
+                        unexpected_modules.append(module_name)
+                if unexpected_modules:
+                    raise ValueError(
+                        f"While loading {lora_dir}, expected"
+                        f" target modules in {expected_lora_modules}"
+                        f" but received {unexpected_modules}."
+                        f" Please verify that the loaded LoRA module is correct"
+                    )
+                # Load tensors if there are only expected modules.
+                for module in f.keys():  # noqa
+                    tensors[module] = f.get_tensor(module)
+        elif os.path.isfile(lora_bin_file_path):
+            # When a bin file is provided, we rely on config to find unexpected
+            # modules.
+            unexpected_modules = []
+            target_modules = config["target_modules"]
             for module in target_modules:
-                # Compatible with more modules, such as:layers.11.self_attn.k_proj
+                # Compatible with more modules,
+                # such as:layers.11.self_attn.k_proj
                 part_name = module.split(".")[-1]
                 if part_name not in expected_lora_modules:
                     unexpected_modules.append(module)
-        # loaded lora's target modules must be a subset of expected_lora_modules
+            # loaded lora's target modules must be a subset of
+            # expected_lora_modules. It is not reliable. See
+            # https://github.com/vllm-project/vllm/pull/5909. But there's no
+            # other better mechanism.
             if unexpected_modules:
                 print(unexpected_modules, "modules")
                 raise ValueError(
@@ -319,9 +351,6 @@ class LoRAModel:
                     f" target modules in {expected_lora_modules}"
                     f" but received {unexpected_modules}."
                     f" Please verify that the loaded LoRA module is correct")
-        if os.path.isfile(lora_tensor_path):
-            tensors = safetensors.torch.load_file(lora_tensor_path)
-        elif os.path.isfile(lora_bin_file_path):
             tensors = torch.load(lora_bin_file_path)
         else:
             raise ValueError(f"{lora_dir} doesn't contain tensors")
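Taken together, the two hunks in vllm/lora/models.py move validation and tensor loading into the same branch: a .safetensors checkpoint is validated against its own keys before anything is loaded, while the config-based check survives only as a fallback for .bin checkpoints, whose keys cannot be inspected without a full torch.load. A condensed sketch of the resulting control flow (the checkpoint file names follow peft's conventions and are assumptions here; the real classmethod carries more state):

import os

import safetensors
import torch


def load_lora_tensors(lora_dir, expected_lora_modules):
    lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
    lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
    if os.path.isfile(lora_tensor_path):
        # Safetensors: keys are cheap to read, so validate them against
        # expected_lora_modules first, then load tensor by tensor.
        tensors = {}
        with safetensors.safe_open(lora_tensor_path, framework="pt") as f:
            # ... raise ValueError here if f.keys() contains modules
            # outside expected_lora_modules ...
            for key in f.keys():
                tensors[key] = f.get_tensor(key)
        return tensors
    elif os.path.isfile(lora_bin_file_path):
        # .bin: no key metadata without a full load, so the caller falls
        # back to checking adapter_config.json's target_modules instead.
        return torch.load(lora_bin_file_path)
    else:
        raise ValueError(f"{lora_dir} doesn't contain tensors")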