[V0 Deprecation] Remove V0 LoRA test (#23418)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
parent 88016c372a
commit 285178b3b8
@@ -3,15 +3,13 @@

import tempfile
from collections import OrderedDict
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock

import pytest
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download

import vllm
from vllm.config import LoRAConfig
from vllm.distributed import (cleanup_dist_env_and_memory,
init_distributed_environment,
initialize_model_parallel)
@@ -21,7 +19,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models.interfaces import SupportsLoRA
from vllm.platforms import current_platform

@@ -104,6 +101,7 @@ def dummy_model() -> nn.Module:
]))
model.config = MagicMock()
model.embedding_modules = {"lm_head": "lm_head"}
model.unpadded_vocab_size = 32000
return model


@@ -137,6 +135,8 @@ def dummy_model_gate_up() -> nn.Module:
],
}
model.embedding_modules = {"lm_head": "lm_head"}
model.unpadded_vocab_size = 32000

return model


@@ -221,29 +221,6 @@ def phi2_lora_files():
return snapshot_download(repo_id="isotr0py/phi-2-test-sql-lora")


@pytest.fixture
def llama_2_7b_engine_extra_embeddings():
cleanup_dist_env_and_memory(shutdown_ray=True)
get_model_old = get_model

def get_model_patched(**kwargs):
kwargs["vllm_config"].lora_config = LoRAConfig(max_loras=4,
max_lora_rank=8)
return get_model_old(**kwargs)

with patch("vllm.worker.model_runner.get_model", get_model_patched):
engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)
yield engine.llm_engine
del engine
cleanup_dist_env_and_memory(shutdown_ray=True)


@pytest.fixture
def llama_2_7b_model_extra_embeddings(llama_2_7b_engine_extra_embeddings):
yield (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker.
model_runner.model)


@pytest.fixture
def reset_default_device():
"""

@@ -5,7 +5,6 @@ import time

import pytest

import vllm.envs as env
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args)
@@ -98,12 +97,10 @@ async def test_add_lora(chatglm3_lora_files):
# Run with warmup
add_lora_tasks = [llm.add_lora(lr) for lr in warmup_run_requests]
add_lora_results = await asyncio.gather(*add_lora_tasks)
if env.VLLM_USE_V1:

# Test that all all_lora calls are successful.
assert all(add_lora_results)
else:
# No way to check V0 engine results as the calls just return None.
pass

time_with_add_lora = await requests_processing_time(
llm, warmup_run_requests)


@@ -113,8 +113,7 @@ def test_llama_lora(sql_lora_files):
enable_lora=True,
# also test odd max_num_seqs
max_num_seqs=13,
max_loras=4,
enable_chunked_prefill=True)
max_loras=4)
generate_and_test(llm, sql_lora_files)


@@ -128,7 +127,6 @@ def test_llama_lora_tp4(sql_lora_files):
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=4,
enable_chunked_prefill=True,
)
generate_and_test(llm, sql_lora_files)

@@ -144,7 +142,6 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):
max_loras=4,
tensor_parallel_size=4,
fully_sharded_loras=True,
enable_chunked_prefill=True,
)
generate_and_test(llm, sql_lora_files)


@@ -21,6 +21,8 @@ from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
WorkerLoRAManager)
from vllm.platforms import current_platform

from .utils import create_peft_lora

EMBEDDING_MODULES = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
@@ -35,17 +37,6 @@ DEVICES = ([
DEFAULT_DTYPE = torch.get_default_dtype()


@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch: pytest.MonkeyPatch):
"""
Some tests depend on V0 internals. Since both V0 and V1 use the same
LoRAModelManager it is okay to just test V0.
"""
with monkeypatch.context() as m:
m.setenv('VLLM_USE_V1', '0')
yield


@pytest.mark.parametrize("device", DEVICES)
def test_from_lora_tensors(sql_lora_files, device):
tensors = load_file(
@@ -326,7 +317,6 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
max_loras=2,
lora_dtype=DEFAULT_DTYPE),
device=device)

assert all(x is None for x in manager.lora_index_to_id)

# Add up to capacity
@@ -430,32 +420,40 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):


@pytest.mark.parametrize("device", DEVICES)
def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
sql_lora_files, device):
def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device,
tmp_path):
lora_config = LoRAConfig(max_lora_rank=8,
max_cpu_loras=4,
max_loras=4,
lora_dtype=DEFAULT_DTYPE)

dummy_lora_files = f"{tmp_path}/lora_adapter"
os.makedirs(dummy_lora_files, exist_ok=True)
create_peft_lora(
dummy_model,
save_dir=dummy_lora_files,
target_modules=["layer1.dense1", "dense2"],
lora_dtype=DEFAULT_DTYPE,
)
worker_adapter_manager = LRUCacheWorkerLoRAManager(
4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
lora_config.lora_extra_vocab_size, lora_config, device,
EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
worker_adapter_manager.create_lora_manager(
llama_2_7b_model_extra_embeddings)
4, 2,
dummy_model.unpadded_vocab_size - lora_config.lora_extra_vocab_size,
lora_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
worker_adapter_manager.create_lora_manager(dummy_model)

mapping = LoRAMapping([], [])
worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("2", 2, sql_lora_files)
LoRARequest("1", 1, dummy_lora_files),
LoRARequest("2", 2, dummy_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {1, 2}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2

worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("3", 3, sql_lora_files),
LoRARequest("4", 4, sql_lora_files)
LoRARequest("1", 1, dummy_lora_files),
LoRARequest("3", 3, dummy_lora_files),
LoRARequest("4", 4, dummy_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
@@ -464,9 +462,9 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("2", 2, sql_lora_files),
LoRARequest("5", 5, sql_lora_files)
LoRARequest("1", 1, dummy_lora_files),
LoRARequest("2", 2, dummy_lora_files),
LoRARequest("5", 5, dummy_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
@@ -475,9 +473,9 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("1", 1, sql_lora_files),
LoRARequest("1", 1, sql_lora_files)
LoRARequest("1", 1, dummy_lora_files),
LoRARequest("1", 1, dummy_lora_files),
LoRARequest("1", 1, dummy_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
@@ -486,9 +484,9 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

worker_adapter_manager.set_active_adapters([
LoRARequest("6", 6, sql_lora_files),
LoRARequest("7", 7, sql_lora_files),
LoRARequest("8", 8, sql_lora_files)
LoRARequest("6", 6, dummy_lora_files),
LoRARequest("7", 7, dummy_lora_files),
LoRARequest("8", 8, dummy_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
@@ -499,11 +497,11 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
# Over capacity
with pytest.raises(RuntimeError):
worker_adapter_manager.set_active_adapters([
LoRARequest("10", 10, sql_lora_files),
LoRARequest("11", 11, sql_lora_files),
LoRARequest("12", 12, sql_lora_files),
LoRARequest("13", 13, sql_lora_files),
LoRARequest("14", 14, sql_lora_files)
LoRARequest("10", 10, dummy_lora_files),
LoRARequest("11", 11, dummy_lora_files),
LoRARequest("12", 12, dummy_lora_files),
LoRARequest("13", 13, dummy_lora_files),
LoRARequest("14", 14, dummy_lora_files)
], mapping)

assert worker_adapter_manager.device == device
@@ -512,33 +510,41 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,


@pytest.mark.parametrize("device", DEVICES)
def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
sql_lora_files, device):
def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device,
tmp_path):
# Should remove every LoRA not specified in the request.
lora_config = LoRAConfig(max_lora_rank=8,
max_cpu_loras=4,
max_loras=4,
lora_dtype=DEFAULT_DTYPE)
worker_adapter_manager = WorkerLoRAManager(
4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
4, 2, dummy_model_gate_up.unpadded_vocab_size -
lora_config.lora_extra_vocab_size, lora_config, device,
EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
worker_adapter_manager.create_lora_manager(
llama_2_7b_model_extra_embeddings)
worker_adapter_manager.create_lora_manager(dummy_model_gate_up)

dummy_lora_files = f"{tmp_path}/lora_adapter"
os.makedirs(dummy_lora_files, exist_ok=True)
create_peft_lora(
dummy_model_gate_up,
save_dir=dummy_lora_files,
target_modules=["layer1.dense1", "dense2"],
lora_dtype=DEFAULT_DTYPE,
)

mapping = LoRAMapping([], [])
worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("2", 2, sql_lora_files)
LoRARequest("1", 1, dummy_lora_files),
LoRARequest("2", 2, dummy_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {1, 2}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2

worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("3", 3, sql_lora_files),
LoRARequest("4", 4, sql_lora_files)
LoRARequest("1", 1, dummy_lora_files),
LoRARequest("3", 3, dummy_lora_files),
LoRARequest("4", 4, dummy_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {1, 3, 4}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
@@ -546,9 +552,9 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 4

worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("2", 2, sql_lora_files),
LoRARequest("5", 5, sql_lora_files)
LoRARequest("1", 1, dummy_lora_files),
LoRARequest("2", 2, dummy_lora_files),
LoRARequest("5", 5, dummy_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {1, 2, 5}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
@@ -556,9 +562,9 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5

worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("1", 1, sql_lora_files),
LoRARequest("1", 1, sql_lora_files)
LoRARequest("1", 1, dummy_lora_files),
LoRARequest("1", 1, dummy_lora_files),
LoRARequest("1", 1, dummy_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {1}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
@@ -566,9 +572,9 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] is None

worker_adapter_manager.set_active_adapters([
LoRARequest("6", 6, sql_lora_files),
LoRARequest("7", 7, sql_lora_files),
LoRARequest("8", 8, sql_lora_files)
LoRARequest("6", 6, dummy_lora_files),
LoRARequest("7", 7, dummy_lora_files),
LoRARequest("8", 8, dummy_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {6, 7, 8}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 8
@@ -578,11 +584,11 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
# Over capacity
with pytest.raises(RuntimeError):
worker_adapter_manager.set_active_adapters([
LoRARequest("10", 10, sql_lora_files),
LoRARequest("11", 11, sql_lora_files),
LoRARequest("12", 12, sql_lora_files),
LoRARequest("13", 13, sql_lora_files),
LoRARequest("14", 14, sql_lora_files)
LoRARequest("10", 10, dummy_lora_files),
LoRARequest("11", 11, dummy_lora_files),
LoRARequest("12", 12, dummy_lora_files),
LoRARequest("13", 13, dummy_lora_files),
LoRARequest("14", 14, dummy_lora_files)
], mapping)

assert worker_adapter_manager.device == device

@@ -50,7 +50,6 @@ def test_mixtral_lora(mixtral_lora_files, tp_size):
max_loras=4,
distributed_executor_backend="ray",
tensor_parallel_size=tp_size,
enable_chunked_prefill=True,
)

expected_lora_output = [

@@ -4,17 +4,14 @@
import os
import random
import tempfile
from typing import Union
from unittest.mock import patch

import vllm.envs as envs
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VllmConfig)
from vllm.lora.models import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.v1.worker.gpu_worker import Worker as V1Worker
from vllm.worker.worker import Worker
from vllm.v1.worker.gpu_worker import Worker

NUM_LORAS = 16

@@ -22,19 +19,12 @@ NUM_LORAS = 16
@patch.dict(os.environ, {"RANK": "0"})
def test_worker_apply_lora(sql_lora_files):

def set_active_loras(worker: Union[Worker, V1Worker],
lora_requests: list[LoRARequest]):
def set_active_loras(worker: Worker, lora_requests: list[LoRARequest]):
lora_mapping = LoRAMapping([], [])
if isinstance(worker, Worker):
# v0 case
worker.model_runner.set_active_loras(lora_requests, lora_mapping)
else:
# v1 case

worker.model_runner.lora_manager.set_active_adapters(
lora_requests, lora_mapping)

worker_cls = V1Worker if envs.VLLM_USE_V1 else Worker

vllm_config = VllmConfig(
model_config=ModelConfig(
"meta-llama/Llama-2-7b-hf",
@@ -62,7 +52,7 @@ def test_worker_apply_lora(sql_lora_files):
max_cpu_loras=NUM_LORAS,
max_loras=NUM_LORAS),
)
worker = worker_cls(
worker = Worker(
vllm_config=vllm_config,
local_rank=0,
rank=0,

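With the V0 branch and the worker_cls switch removed, the test's set_active_loras helper presumably reduces to the V1-only path already visible in the removed lines above. A minimal sketch of that shape, reconstructed from the removed "v1 case" branch rather than copied from the updated file:

# Reconstructed for illustration from the removed "v1 case" branch above;
# not a verbatim copy of the new test file.
from vllm.lora.models import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.v1.worker.gpu_worker import Worker


def set_active_loras(worker: Worker, lora_requests: list[LoRARequest]):
    lora_mapping = LoRAMapping([], [])
    # The V1 worker activates adapters through its model runner's lora_manager.
    worker.model_runner.lora_manager.set_active_adapters(
        lora_requests, lora_mapping)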
@@ -1,10 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
import os
from dataclasses import dataclass
from typing import Optional, Union

import torch
from safetensors.torch import save_file

from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights

@@ -340,3 +343,76 @@ def generate_data_for_nslices(
seq_len_tensor,
indices,
)


def create_peft_lora(
model: torch.nn.Module,
save_dir: str,
target_modules: list[str],
rank: int = 8,
alpha: int = 16,
dropout: float = 0.1,
lora_dtype: torch.dtype = torch.float16,
) -> dict[str, torch.Tensor]:
lora_weights = {}
adapter_config = {
"peft_type": "LORA",
"auto_mapping": None,
"base_model_name_or_path": "dummy_model",
"revision": None,
"task_type": "CAUSAL_LM",
"inference_mode": False,
"r": rank,
"lora_alpha": alpha,
"lora_dropout": dropout,
"fan_in_fan_out": False,
"bias": "none",
"modules_to_save": None,
"init_lora_weights": True,
"layers_to_transform": None,
"layers_pattern": None,
"target_modules": target_modules,
"exclude_modules": None,
"use_rslora": False,
"use_dora": False,
"loftq_config": None,
}

for module_name in target_modules:

module = model
for attr in module_name.split("."):
module = getattr(module, attr)

if hasattr(module, "input_size") and hasattr(module, "output_size"):

in_features = module.input_size
out_features = module.output_size

elif hasattr(module, "embedding_dim") and hasattr(
module, "num_embeddings"):
# ParallelLMHead
in_features = module.embedding_dim
out_features = module.num_embeddings
else:
raise ValueError(
f"Unable to determine dimensions for module {module_name}")

lora_A = torch.randn(rank, in_features, dtype=lora_dtype)

torch.nn.init.kaiming_uniform_(lora_A, a=5**0.5)

lora_B = torch.zeros(out_features, rank, dtype=lora_dtype)

# PEFT style
lora_weights[f"base_model.model.{module_name}.lora_A.weight"] = lora_A
lora_weights[f"base_model.model.{module_name}.lora_B.weight"] = lora_B

config_path = os.path.join(save_dir, "adapter_config.json")
with open(config_path, "w", encoding="utf-8") as f:
json.dump(adapter_config, f, indent=2, ensure_ascii=False)

weights_path = os.path.join(save_dir, "adapter_model.safetensors")
save_file(lora_weights, weights_path)

return lora_weights

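The new create_peft_lora helper writes a PEFT-style adapter_config.json and adapter_model.safetensors into save_dir, so the worker-manager tests above can point a LoRARequest at a locally generated adapter instead of a downloaded one. A small usage sketch mirroring test_worker_adapter_manager; the wrapper function name is illustrative only, while dummy_model_gate_up, DEFAULT_DTYPE, and tmp_path are the fixtures and constants shown in the diff:

# Illustrative only: mirrors the pattern used in the updated tests above.
import os

from vllm.lora.request import LoRARequest

from .utils import create_peft_lora  # helper added in this commit


def make_dummy_lora_request(dummy_model_gate_up, tmp_path, default_dtype):
    dummy_lora_files = f"{tmp_path}/lora_adapter"
    os.makedirs(dummy_lora_files, exist_ok=True)
    # Generates adapter_config.json + adapter_model.safetensors in PEFT layout.
    create_peft_lora(
        dummy_model_gate_up,
        save_dir=dummy_lora_files,
        target_modules=["layer1.dense1", "dense2"],
        lora_dtype=default_dtype,
    )
    # The local directory now stands in for a downloaded LoRA adapter.
    return LoRARequest("1", 1, dummy_lora_files)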