[V0 Deprecation] Remove V0 LoRA test (#23418)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent 88016c372a
commit 285178b3b8
@@ -3,15 +3,13 @@

 import tempfile
 from collections import OrderedDict
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock

 import pytest
 import torch
 import torch.nn as nn
 from huggingface_hub import snapshot_download

-import vllm
-from vllm.config import LoRAConfig
 from vllm.distributed import (cleanup_dist_env_and_memory,
                               init_distributed_environment,
                               initialize_model_parallel)
@@ -21,7 +19,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
-from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models.interfaces import SupportsLoRA
 from vllm.platforms import current_platform

@@ -104,6 +101,7 @@ def dummy_model() -> nn.Module:
         ]))
     model.config = MagicMock()
     model.embedding_modules = {"lm_head": "lm_head"}
+    model.unpadded_vocab_size = 32000
     return model


@@ -137,6 +135,8 @@ def dummy_model_gate_up() -> nn.Module:
         ],
     }
     model.embedding_modules = {"lm_head": "lm_head"}
+    model.unpadded_vocab_size = 32000
+
     return model


@@ -221,29 +221,6 @@ def phi2_lora_files():
     return snapshot_download(repo_id="isotr0py/phi-2-test-sql-lora")


-@pytest.fixture
-def llama_2_7b_engine_extra_embeddings():
-    cleanup_dist_env_and_memory(shutdown_ray=True)
-    get_model_old = get_model
-
-    def get_model_patched(**kwargs):
-        kwargs["vllm_config"].lora_config = LoRAConfig(max_loras=4,
-                                                       max_lora_rank=8)
-        return get_model_old(**kwargs)
-
-    with patch("vllm.worker.model_runner.get_model", get_model_patched):
-        engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)
-    yield engine.llm_engine
-    del engine
-    cleanup_dist_env_and_memory(shutdown_ray=True)
-
-
-@pytest.fixture
-def llama_2_7b_model_extra_embeddings(llama_2_7b_engine_extra_embeddings):
-    yield (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker.
-           model_runner.model)
-
-
 @pytest.fixture
 def reset_default_device():
     """
@@ -5,7 +5,6 @@ import time

 import pytest

-import vllm.envs as env
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.entrypoints.openai.api_server import (
     build_async_engine_client_from_engine_args)
@@ -98,12 +97,10 @@ async def test_add_lora(chatglm3_lora_files):
     # Run with warmup
     add_lora_tasks = [llm.add_lora(lr) for lr in warmup_run_requests]
     add_lora_results = await asyncio.gather(*add_lora_tasks)
-    if env.VLLM_USE_V1:
-        # Test that all all_lora calls are successful.
-        assert all(add_lora_results)
-    else:
-        # No way to check V0 engine results as the calls just return None.
-        pass
+    # Test that all all_lora calls are successful.
+    assert all(add_lora_results)
     time_with_add_lora = await requests_processing_time(
         llm, warmup_run_requests)

@@ -113,8 +113,7 @@ def test_llama_lora(sql_lora_files):
         enable_lora=True,
         # also test odd max_num_seqs
         max_num_seqs=13,
-        max_loras=4,
-        enable_chunked_prefill=True)
+        max_loras=4)
     generate_and_test(llm, sql_lora_files)


@@ -128,7 +127,6 @@ def test_llama_lora_tp4(sql_lora_files):
         max_num_seqs=16,
         max_loras=4,
         tensor_parallel_size=4,
-        enable_chunked_prefill=True,
     )
     generate_and_test(llm, sql_lora_files)

@@ -144,7 +142,6 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):
         max_loras=4,
         tensor_parallel_size=4,
         fully_sharded_loras=True,
-        enable_chunked_prefill=True,
     )
     generate_and_test(llm, sql_lora_files)

@@ -21,6 +21,8 @@ from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
                                       WorkerLoRAManager)
 from vllm.platforms import current_platform

+from .utils import create_peft_lora
+
 EMBEDDING_MODULES = {
     "embed_tokens": "input_embeddings",
     "lm_head": "output_embeddings",
@@ -35,17 +37,6 @@ DEVICES = ([
 DEFAULT_DTYPE = torch.get_default_dtype()


-@pytest.fixture(scope="function", autouse=True)
-def use_v0_only(monkeypatch: pytest.MonkeyPatch):
-    """
-    Some tests depend on V0 internals. Since both V0 and V1 use the same
-    LoRAModelManager it is okay to just test V0.
-    """
-    with monkeypatch.context() as m:
-        m.setenv('VLLM_USE_V1', '0')
-        yield
-
-
 @pytest.mark.parametrize("device", DEVICES)
 def test_from_lora_tensors(sql_lora_files, device):
     tensors = load_file(
@@ -326,7 +317,6 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
                    max_loras=2,
                    lora_dtype=DEFAULT_DTYPE),
         device=device)
-
     assert all(x is None for x in manager.lora_index_to_id)

     # Add up to capacity
@@ -430,32 +420,40 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):


 @pytest.mark.parametrize("device", DEVICES)
-def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
-                                          sql_lora_files, device):
+def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device,
+                                          tmp_path):
     lora_config = LoRAConfig(max_lora_rank=8,
                              max_cpu_loras=4,
                              max_loras=4,
                              lora_dtype=DEFAULT_DTYPE)

+    dummy_lora_files = f"{tmp_path}/lora_adapter"
+    os.makedirs(dummy_lora_files, exist_ok=True)
+    create_peft_lora(
+        dummy_model,
+        save_dir=dummy_lora_files,
+        target_modules=["layer1.dense1", "dense2"],
+        lora_dtype=DEFAULT_DTYPE,
+    )
     worker_adapter_manager = LRUCacheWorkerLoRAManager(
-        4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
-        lora_config.lora_extra_vocab_size, lora_config, device,
-        EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
-    worker_adapter_manager.create_lora_manager(
-        llama_2_7b_model_extra_embeddings)
+        4, 2,
+        dummy_model.unpadded_vocab_size - lora_config.lora_extra_vocab_size,
+        lora_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
+    worker_adapter_manager.create_lora_manager(dummy_model)

     mapping = LoRAMapping([], [])
     worker_adapter_manager.set_active_adapters([
-        LoRARequest("1", 1, sql_lora_files),
-        LoRARequest("2", 2, sql_lora_files)
+        LoRARequest("1", 1, dummy_lora_files),
+        LoRARequest("2", 2, dummy_lora_files)
     ], mapping)
     assert worker_adapter_manager.list_adapters() == {1, 2}
     assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
     assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2

     worker_adapter_manager.set_active_adapters([
-        LoRARequest("1", 1, sql_lora_files),
-        LoRARequest("3", 3, sql_lora_files),
-        LoRARequest("4", 4, sql_lora_files)
+        LoRARequest("1", 1, dummy_lora_files),
+        LoRARequest("3", 3, dummy_lora_files),
+        LoRARequest("4", 4, dummy_lora_files)
     ], mapping)
     assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4}
     assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
@@ -464,9 +462,9 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
     assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

     worker_adapter_manager.set_active_adapters([
-        LoRARequest("1", 1, sql_lora_files),
-        LoRARequest("2", 2, sql_lora_files),
-        LoRARequest("5", 5, sql_lora_files)
+        LoRARequest("1", 1, dummy_lora_files),
+        LoRARequest("2", 2, dummy_lora_files),
+        LoRARequest("5", 5, dummy_lora_files)
     ], mapping)
     assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
     assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
@@ -475,9 +473,9 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
     assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

     worker_adapter_manager.set_active_adapters([
-        LoRARequest("1", 1, sql_lora_files),
-        LoRARequest("1", 1, sql_lora_files),
-        LoRARequest("1", 1, sql_lora_files)
+        LoRARequest("1", 1, dummy_lora_files),
+        LoRARequest("1", 1, dummy_lora_files),
+        LoRARequest("1", 1, dummy_lora_files)
     ], mapping)
     assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
     assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
@@ -486,9 +484,9 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
     assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

     worker_adapter_manager.set_active_adapters([
-        LoRARequest("6", 6, sql_lora_files),
-        LoRARequest("7", 7, sql_lora_files),
-        LoRARequest("8", 8, sql_lora_files)
+        LoRARequest("6", 6, dummy_lora_files),
+        LoRARequest("7", 7, dummy_lora_files),
+        LoRARequest("8", 8, dummy_lora_files)
     ], mapping)
     assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8}
     assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
@@ -499,11 +497,11 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
     # Over capacity
     with pytest.raises(RuntimeError):
         worker_adapter_manager.set_active_adapters([
-            LoRARequest("10", 10, sql_lora_files),
-            LoRARequest("11", 11, sql_lora_files),
-            LoRARequest("12", 12, sql_lora_files),
-            LoRARequest("13", 13, sql_lora_files),
-            LoRARequest("14", 14, sql_lora_files)
+            LoRARequest("10", 10, dummy_lora_files),
+            LoRARequest("11", 11, dummy_lora_files),
+            LoRARequest("12", 12, dummy_lora_files),
+            LoRARequest("13", 13, dummy_lora_files),
+            LoRARequest("14", 14, dummy_lora_files)
         ], mapping)

     assert worker_adapter_manager.device == device
@@ -512,33 +510,41 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,


 @pytest.mark.parametrize("device", DEVICES)
-def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
-                                sql_lora_files, device):
+def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device,
+                                tmp_path):
     # Should remove every LoRA not specified in the request.
     lora_config = LoRAConfig(max_lora_rank=8,
                              max_cpu_loras=4,
                              max_loras=4,
                              lora_dtype=DEFAULT_DTYPE)
     worker_adapter_manager = WorkerLoRAManager(
-        4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
+        4, 2, dummy_model_gate_up.unpadded_vocab_size -
         lora_config.lora_extra_vocab_size, lora_config, device,
         EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
-    worker_adapter_manager.create_lora_manager(
-        llama_2_7b_model_extra_embeddings)
+    worker_adapter_manager.create_lora_manager(dummy_model_gate_up)
+    dummy_lora_files = f"{tmp_path}/lora_adapter"
+    os.makedirs(dummy_lora_files, exist_ok=True)
+    create_peft_lora(
+        dummy_model_gate_up,
+        save_dir=dummy_lora_files,
+        target_modules=["layer1.dense1", "dense2"],
+        lora_dtype=DEFAULT_DTYPE,
+    )

     mapping = LoRAMapping([], [])
     worker_adapter_manager.set_active_adapters([
-        LoRARequest("1", 1, sql_lora_files),
-        LoRARequest("2", 2, sql_lora_files)
+        LoRARequest("1", 1, dummy_lora_files),
+        LoRARequest("2", 2, dummy_lora_files)
     ], mapping)
     assert worker_adapter_manager.list_adapters() == {1, 2}
     assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
     assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2

     worker_adapter_manager.set_active_adapters([
-        LoRARequest("1", 1, sql_lora_files),
-        LoRARequest("3", 3, sql_lora_files),
-        LoRARequest("4", 4, sql_lora_files)
+        LoRARequest("1", 1, dummy_lora_files),
+        LoRARequest("3", 3, dummy_lora_files),
+        LoRARequest("4", 4, dummy_lora_files)
     ], mapping)
     assert worker_adapter_manager.list_adapters() == {1, 3, 4}
     assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
@@ -546,9 +552,9 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
     assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 4

     worker_adapter_manager.set_active_adapters([
-        LoRARequest("1", 1, sql_lora_files),
-        LoRARequest("2", 2, sql_lora_files),
-        LoRARequest("5", 5, sql_lora_files)
+        LoRARequest("1", 1, dummy_lora_files),
+        LoRARequest("2", 2, dummy_lora_files),
+        LoRARequest("5", 5, dummy_lora_files)
     ], mapping)
     assert worker_adapter_manager.list_adapters() == {1, 2, 5}
     assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
@@ -556,9 +562,9 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
     assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5

     worker_adapter_manager.set_active_adapters([
-        LoRARequest("1", 1, sql_lora_files),
-        LoRARequest("1", 1, sql_lora_files),
-        LoRARequest("1", 1, sql_lora_files)
+        LoRARequest("1", 1, dummy_lora_files),
+        LoRARequest("1", 1, dummy_lora_files),
+        LoRARequest("1", 1, dummy_lora_files)
     ], mapping)
     assert worker_adapter_manager.list_adapters() == {1}
     assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
@@ -566,9 +572,9 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
     assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] is None

     worker_adapter_manager.set_active_adapters([
-        LoRARequest("6", 6, sql_lora_files),
-        LoRARequest("7", 7, sql_lora_files),
-        LoRARequest("8", 8, sql_lora_files)
+        LoRARequest("6", 6, dummy_lora_files),
+        LoRARequest("7", 7, dummy_lora_files),
+        LoRARequest("8", 8, dummy_lora_files)
     ], mapping)
     assert worker_adapter_manager.list_adapters() == {6, 7, 8}
     assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 8
@@ -578,11 +584,11 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
     # Over capacity
     with pytest.raises(RuntimeError):
         worker_adapter_manager.set_active_adapters([
-            LoRARequest("10", 10, sql_lora_files),
-            LoRARequest("11", 11, sql_lora_files),
-            LoRARequest("12", 12, sql_lora_files),
-            LoRARequest("13", 13, sql_lora_files),
-            LoRARequest("14", 14, sql_lora_files)
+            LoRARequest("10", 10, dummy_lora_files),
+            LoRARequest("11", 11, dummy_lora_files),
+            LoRARequest("12", 12, dummy_lora_files),
+            LoRARequest("13", 13, dummy_lora_files),
+            LoRARequest("14", 14, dummy_lora_files)
         ], mapping)

     assert worker_adapter_manager.device == device
@@ -50,7 +50,6 @@ def test_mixtral_lora(mixtral_lora_files, tp_size):
         max_loras=4,
         distributed_executor_backend="ray",
         tensor_parallel_size=tp_size,
-        enable_chunked_prefill=True,
     )

     expected_lora_output = [
@@ -4,17 +4,14 @@
 import os
 import random
 import tempfile
-from typing import Union
 from unittest.mock import patch

-import vllm.envs as envs
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ParallelConfig, SchedulerConfig,
                          VllmConfig)
 from vllm.lora.models import LoRAMapping
 from vllm.lora.request import LoRARequest
-from vllm.v1.worker.gpu_worker import Worker as V1Worker
-from vllm.worker.worker import Worker
+from vllm.v1.worker.gpu_worker import Worker

 NUM_LORAS = 16

@@ -22,18 +19,11 @@ NUM_LORAS = 16
 @patch.dict(os.environ, {"RANK": "0"})
 def test_worker_apply_lora(sql_lora_files):

-    def set_active_loras(worker: Union[Worker, V1Worker],
-                         lora_requests: list[LoRARequest]):
+    def set_active_loras(worker: Worker, lora_requests: list[LoRARequest]):
         lora_mapping = LoRAMapping([], [])
-        if isinstance(worker, Worker):
-            # v0 case
-            worker.model_runner.set_active_loras(lora_requests, lora_mapping)
-        else:
-            # v1 case
-            worker.model_runner.lora_manager.set_active_adapters(
-                lora_requests, lora_mapping)
-
-    worker_cls = V1Worker if envs.VLLM_USE_V1 else Worker
+        worker.model_runner.lora_manager.set_active_adapters(
+            lora_requests, lora_mapping)

     vllm_config = VllmConfig(
         model_config=ModelConfig(
@@ -62,7 +52,7 @@ def test_worker_apply_lora(sql_lora_files):
             max_cpu_loras=NUM_LORAS,
             max_loras=NUM_LORAS),
     )
-    worker = worker_cls(
+    worker = Worker(
         vllm_config=vllm_config,
         local_rank=0,
         rank=0,
@@ -1,10 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+import json
+import os
 from dataclasses import dataclass
 from typing import Optional, Union

 import torch
+from safetensors.torch import save_file

 from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights

@@ -340,3 +343,76 @@ def generate_data_for_nslices(
         seq_len_tensor,
         indices,
     )
+
+
+def create_peft_lora(
+    model: torch.nn.Module,
+    save_dir: str,
+    target_modules: list[str],
+    rank: int = 8,
+    alpha: int = 16,
+    dropout: float = 0.1,
+    lora_dtype: torch.dtype = torch.float16,
+) -> dict[str, torch.Tensor]:
+    lora_weights = {}
+    adapter_config = {
+        "peft_type": "LORA",
+        "auto_mapping": None,
+        "base_model_name_or_path": "dummy_model",
+        "revision": None,
+        "task_type": "CAUSAL_LM",
+        "inference_mode": False,
+        "r": rank,
+        "lora_alpha": alpha,
+        "lora_dropout": dropout,
+        "fan_in_fan_out": False,
+        "bias": "none",
+        "modules_to_save": None,
+        "init_lora_weights": True,
+        "layers_to_transform": None,
+        "layers_pattern": None,
+        "target_modules": target_modules,
+        "exclude_modules": None,
+        "use_rslora": False,
+        "use_dora": False,
+        "loftq_config": None,
+    }
+
+    for module_name in target_modules:
+
+        module = model
+        for attr in module_name.split("."):
+            module = getattr(module, attr)
+
+        if hasattr(module, "input_size") and hasattr(module, "output_size"):
+
+            in_features = module.input_size
+            out_features = module.output_size
+
+        elif hasattr(module, "embedding_dim") and hasattr(
+                module, "num_embeddings"):
+            # ParallelLMHead
+            in_features = module.embedding_dim
+            out_features = module.num_embeddings
+        else:
+            raise ValueError(
+                f"Unable to determine dimensions for module {module_name}")
+
+        lora_A = torch.randn(rank, in_features, dtype=lora_dtype)
+
+        torch.nn.init.kaiming_uniform_(lora_A, a=5**0.5)
+
+        lora_B = torch.zeros(out_features, rank, dtype=lora_dtype)
+
+        # PEFT style
+        lora_weights[f"base_model.model.{module_name}.lora_A.weight"] = lora_A
+        lora_weights[f"base_model.model.{module_name}.lora_B.weight"] = lora_B
+
+    config_path = os.path.join(save_dir, "adapter_config.json")
+    with open(config_path, "w", encoding="utf-8") as f:
+        json.dump(adapter_config, f, indent=2, ensure_ascii=False)
+
+    weights_path = os.path.join(save_dir, "adapter_model.safetensors")
+    save_file(lora_weights, weights_path)
+
+    return lora_weights
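
Note on the new test flow: instead of downloading the Llama-2 LoRA checkpoints, the rewritten worker-manager tests generate a throwaway PEFT-style adapter on disk with the new create_peft_lora helper and point LoRARequest at that directory. The sketch below shows that flow in isolation; it is illustrative only. DenseStub is a hypothetical stand-in that exposes the input_size / output_size attributes the helper inspects (the tests use the dummy_model / dummy_model_gate_up fixtures instead), and the relative import assumes the snippet sits next to the tests' utils module, which the tests import as `from .utils import create_peft_lora`.

# Illustrative sketch, not part of the commit.
import tempfile

import torch
import torch.nn as nn

from vllm.lora.request import LoRARequest

from .utils import create_peft_lora  # as imported by the updated tests


class DenseStub(nn.Module):
    """Hypothetical layer carrying the dimension attributes the helper reads."""

    def __init__(self, input_size: int, output_size: int):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size


model = nn.Module()
model.dense1 = DenseStub(input_size=16, output_size=32)

with tempfile.TemporaryDirectory() as adapter_dir:
    # Writes adapter_config.json and adapter_model.safetensors into adapter_dir.
    create_peft_lora(model,
                     save_dir=adapter_dir,
                     target_modules=["dense1"],
                     lora_dtype=torch.float32)
    # The directory can then back a LoRARequest, just as the rewritten tests
    # do with their tmp_path-based dummy_lora_files.
    lora_request = LoRARequest("dummy", 1, adapter_dir)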