[Bugfix] Validate lora adapters to avoid crashing server (#11727)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
parent cf5f000d21 · commit ac2f3f7fee
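The behavior being fixed is easiest to see from the client side. Below is a minimal sketch (not part of this diff) of the dynamic-load call that the new tests exercise, assuming a local vLLM OpenAI server started with --enable-lora and VLLM_ALLOW_RUNTIME_LORA_UPDATING=True; the base URL, API key, adapter name, and adapter path are all placeholders. Before this change, pointing lora_path at a broken adapter could crash the whole server; with the validation added here the call fails with an HTTP 400/404 error instead.

import asyncio

import openai  # official client, used the same way as in the tests below


async def load_adapter():
    # Placeholder endpoint and dummy key for a locally running server.
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    # POSTs to /v1/load_lora_adapter; the server must be started with
    # VLLM_ALLOW_RUNTIME_LORA_UPDATING=True for this route to be enabled.
    result = await client.post("load_lora_adapter",
                               cast_to=str,
                               body={
                                   "lora_name": "my-adapter",        # placeholder
                                   "lora_path": "/path/to/adapter",  # placeholder
                               })
    print(result)  # e.g. "Success: LoRA adapter 'my-adapter' added successfully."


asyncio.run(load_adapter())

A bad lora_path now surfaces as openai.NotFoundError or openai.BadRequestError on the client, as the tests in the new file below assert.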
tests/entrypoints/openai/test_lora_adapters.py (new file, 269 lines)
@@ -0,0 +1,269 @@
import asyncio
import json
import shutil
from contextlib import suppress

import openai  # use the official client for correctness check
import pytest
import pytest_asyncio
# downloading lora to test lora requests
from huggingface_hub import snapshot_download

from ...utils import RemoteOpenAIServer

# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# technically this needs Mistral-7B-v0.1 as base, but we're not testing
# generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"


@pytest.fixture(scope="module")
def zephyr_lora_files():
    return snapshot_download(repo_id=LORA_NAME)


@pytest.fixture(scope="module")
def server_with_lora_modules_json(zephyr_lora_files):
    # Define the json format LoRA module configurations
    lora_module_1 = {
        "name": "zephyr-lora",
        "path": zephyr_lora_files,
        "base_model_name": MODEL_NAME
    }

    lora_module_2 = {
        "name": "zephyr-lora2",
        "path": zephyr_lora_files,
        "base_model_name": MODEL_NAME
    }

    args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "8192",
        "--enforce-eager",
        # lora config below
        "--enable-lora",
        "--lora-modules",
        json.dumps(lora_module_1),
        json.dumps(lora_module_2),
        "--max-lora-rank",
        "64",
        "--max-cpu-loras",
        "2",
        "--max-num-seqs",
        "64",
    ]

    # Enable the /v1/load_lora_adapter endpoint
    envs = {"VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True"}

    with RemoteOpenAIServer(MODEL_NAME, args, env_dict=envs) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client(server_with_lora_modules_json):
    async with server_with_lora_modules_json.get_async_client(
    ) as async_client:
        yield async_client


@pytest.mark.asyncio
async def test_static_lora_lineage(client: openai.AsyncOpenAI,
                                   zephyr_lora_files):
    models = await client.models.list()
    models = models.data
    served_model = models[0]
    lora_models = models[1:]
    assert served_model.id == MODEL_NAME
    assert served_model.root == MODEL_NAME
    assert served_model.parent is None
    assert all(lora_model.root == zephyr_lora_files
               for lora_model in lora_models)
    assert all(lora_model.parent == MODEL_NAME for lora_model in lora_models)
    assert lora_models[0].id == "zephyr-lora"
    assert lora_models[1].id == "zephyr-lora2"


@pytest.mark.asyncio
async def test_dynamic_lora_lineage(client: openai.AsyncOpenAI,
                                    zephyr_lora_files):

    response = await client.post("load_lora_adapter",
                                 cast_to=str,
                                 body={
                                     "lora_name": "zephyr-lora-3",
                                     "lora_path": zephyr_lora_files
                                 })
    # Ensure adapter loads before querying /models
    assert "success" in response

    models = await client.models.list()
    models = models.data
    dynamic_lora_model = models[-1]
    assert dynamic_lora_model.root == zephyr_lora_files
    assert dynamic_lora_model.parent == MODEL_NAME
    assert dynamic_lora_model.id == "zephyr-lora-3"


@pytest.mark.asyncio
async def test_dynamic_lora_not_found(client: openai.AsyncOpenAI):
    with pytest.raises(openai.NotFoundError):
        await client.post("load_lora_adapter",
                          cast_to=str,
                          body={
                              "lora_name": "notfound",
                              "lora_path": "/not/an/adapter"
                          })


@pytest.mark.asyncio
async def test_dynamic_lora_invalid_files(client: openai.AsyncOpenAI,
                                          tmp_path):
    invalid_files = tmp_path / "invalid_files"
    invalid_files.mkdir()
    (invalid_files / "adapter_config.json").write_text("this is not json")

    with pytest.raises(openai.BadRequestError):
        await client.post("load_lora_adapter",
                          cast_to=str,
                          body={
                              "lora_name": "invalid-json",
                              "lora_path": str(invalid_files)
                          })


@pytest.mark.asyncio
async def test_dynamic_lora_invalid_lora_rank(client: openai.AsyncOpenAI,
                                              tmp_path, zephyr_lora_files):
    invalid_rank = tmp_path / "invalid_rank"

    # Copy adapter from zephyr_lora_files to invalid_rank
    shutil.copytree(zephyr_lora_files, invalid_rank)

    with open(invalid_rank / "adapter_config.json") as f:
        adapter_config = json.load(f)

    print(adapter_config)

    # Change rank to invalid value
    adapter_config["r"] = 1024
    with open(invalid_rank / "adapter_config.json", "w") as f:
        json.dump(adapter_config, f)

    with pytest.raises(openai.BadRequestError,
                       match="is greater than max_lora_rank"):
        await client.post("load_lora_adapter",
                          cast_to=str,
                          body={
                              "lora_name": "invalid-json",
                              "lora_path": str(invalid_rank)
                          })


@pytest.mark.asyncio
async def test_multiple_lora_adapters(client: openai.AsyncOpenAI, tmp_path,
                                      zephyr_lora_files):
    """Validate that many LoRA adapters can be dynamically registered and
    used for inference concurrently."""

    # This test file configures the server with --max-cpu-loras=2 and this
    # test will concurrently load 10 adapters, so it should flex the LRU cache
    async def load_and_run_adapter(adapter_name: str):
        await client.post("load_lora_adapter",
                          cast_to=str,
                          body={
                              "lora_name": adapter_name,
                              "lora_path": str(zephyr_lora_files)
                          })
        for _ in range(3):
            await client.completions.create(
                model=adapter_name,
                prompt=["Hello there", "Foo bar bazz buzz"],
                max_tokens=5,
            )

    lora_tasks = []
    for i in range(10):
        lora_tasks.append(
            asyncio.create_task(load_and_run_adapter(f"adapter_{i}")))

    results, _ = await asyncio.wait(lora_tasks)

    for r in results:
        assert not isinstance(r, Exception), f"Got exception {r}"


@pytest.mark.asyncio
async def test_loading_invalid_adapters_does_not_break_others(
        client: openai.AsyncOpenAI, tmp_path, zephyr_lora_files):

    invalid_files = tmp_path / "invalid_files"
    invalid_files.mkdir()
    (invalid_files / "adapter_config.json").write_text("this is not json")

    stop_good_requests_event = asyncio.Event()

    async def run_good_requests(client):
        # Run chat completions requests until event set
        results = []

        while not stop_good_requests_event.is_set():
            try:
                batch = await client.completions.create(
                    model="zephyr-lora",
                    prompt=["Hello there", "Foo bar bazz buzz"],
                    max_tokens=5,
                )
                results.append(batch)
            except Exception as e:
                results.append(e)

        return results

    # Create task to run good requests
    good_task = asyncio.create_task(run_good_requests(client))

    # Run a bunch of bad adapter loads
    for _ in range(25):
        with suppress(openai.NotFoundError):
            await client.post("load_lora_adapter",
                              cast_to=str,
                              body={
                                  "lora_name": "notfound",
                                  "lora_path": "/not/an/adapter"
                              })
    for _ in range(25):
        with suppress(openai.BadRequestError):
            await client.post("load_lora_adapter",
                              cast_to=str,
                              body={
                                  "lora_name": "invalid",
                                  "lora_path": str(invalid_files)
                              })

    # Ensure all the running requests with lora adapters succeeded
    stop_good_requests_event.set()
    results = await good_task
    for r in results:
        assert not isinstance(r, Exception), f"Got exception {r}"

    # Ensure we can load another adapter and run it
    await client.post("load_lora_adapter",
                      cast_to=str,
                      body={
                          "lora_name": "valid",
                          "lora_path": zephyr_lora_files
                      })
    await client.completions.create(
        model="valid",
        prompt=["Hello there", "Foo bar bazz buzz"],
        max_tokens=5,
    )
@@ -1,109 +0,0 @@ (entire file deleted; its lineage tests are superseded by test_lora_adapters.py above)
import json

import openai  # use the official client for correctness check
import pytest
import pytest_asyncio
# downloading lora to test lora requests
from huggingface_hub import snapshot_download

from ...utils import RemoteOpenAIServer

# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# technically this needs Mistral-7B-v0.1 as base, but we're not testing
# generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"


@pytest.fixture(scope="module")
def zephyr_lora_files():
    return snapshot_download(repo_id=LORA_NAME)


@pytest.fixture(scope="module")
def server_with_lora_modules_json(zephyr_lora_files):
    # Define the json format LoRA module configurations
    lora_module_1 = {
        "name": "zephyr-lora",
        "path": zephyr_lora_files,
        "base_model_name": MODEL_NAME
    }

    lora_module_2 = {
        "name": "zephyr-lora2",
        "path": zephyr_lora_files,
        "base_model_name": MODEL_NAME
    }

    args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "8192",
        "--enforce-eager",
        # lora config below
        "--enable-lora",
        "--lora-modules",
        json.dumps(lora_module_1),
        json.dumps(lora_module_2),
        "--max-lora-rank",
        "64",
        "--max-cpu-loras",
        "2",
        "--max-num-seqs",
        "64",
    ]

    # Enable the /v1/load_lora_adapter endpoint
    envs = {"VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True"}

    with RemoteOpenAIServer(MODEL_NAME, args, env_dict=envs) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client_for_lora_lineage(server_with_lora_modules_json):
    async with server_with_lora_modules_json.get_async_client(
    ) as async_client:
        yield async_client


@pytest.mark.asyncio
async def test_static_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI,
                                   zephyr_lora_files):
    models = await client_for_lora_lineage.models.list()
    models = models.data
    served_model = models[0]
    lora_models = models[1:]
    assert served_model.id == MODEL_NAME
    assert served_model.root == MODEL_NAME
    assert served_model.parent is None
    assert all(lora_model.root == zephyr_lora_files
               for lora_model in lora_models)
    assert all(lora_model.parent == MODEL_NAME for lora_model in lora_models)
    assert lora_models[0].id == "zephyr-lora"
    assert lora_models[1].id == "zephyr-lora2"


@pytest.mark.asyncio
async def test_dynamic_lora_lineage(
        client_for_lora_lineage: openai.AsyncOpenAI, zephyr_lora_files):

    response = await client_for_lora_lineage.post("load_lora_adapter",
                                                  cast_to=str,
                                                  body={
                                                      "lora_name":
                                                      "zephyr-lora-3",
                                                      "lora_path":
                                                      zephyr_lora_files
                                                  })
    # Ensure adapter loads before querying /models
    assert "success" in response

    models = await client_for_lora_lineage.models.list()
    models = models.data
    dynamic_lora_model = models[-1]
    assert dynamic_lora_model.root == zephyr_lora_files
    assert dynamic_lora_model.parent == MODEL_NAME
    assert dynamic_lora_model.id == "zephyr-lora-3"
@@ -52,7 +52,7 @@ async def _async_serving_chat_init():
    engine = MockEngine()
    model_config = await engine.get_model_config()

-   models = OpenAIServingModels(model_config, BASE_MODEL_PATHS)
+   models = OpenAIServingModels(engine, model_config, BASE_MODEL_PATHS)
    serving_completion = OpenAIServingChat(engine,
                                           model_config,
                                           models,

@@ -73,7 +73,8 @@ def test_serving_chat_should_set_correct_max_tokens():
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False

-   models = OpenAIServingModels(base_model_paths=BASE_MODEL_PATHS,
+   models = OpenAIServingModels(engine_client=mock_engine,
+                                base_model_paths=BASE_MODEL_PATHS,
                                 model_config=MockModelConfig())
    serving_chat = OpenAIServingChat(mock_engine,
                                     MockModelConfig(),

@@ -116,7 +117,8 @@ def test_serving_chat_could_load_correct_generation_config():
    mock_engine.errored = False

    # Initialize the serving chat
-   models = OpenAIServingModels(base_model_paths=BASE_MODEL_PATHS,
+   models = OpenAIServingModels(engine_client=mock_engine,
+                                base_model_paths=BASE_MODEL_PATHS,
                                 model_config=mock_model_config)
    serving_chat = OpenAIServingChat(mock_engine,
                                     mock_model_config,
@@ -4,6 +4,7 @@ from unittest.mock import MagicMock
import pytest

from vllm.config import ModelConfig
+from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.protocol import (ErrorResponse,
                                              LoadLoraAdapterRequest,
                                              UnloadLoraAdapterRequest)

@@ -21,13 +22,16 @@ LORA_UNLOADING_SUCCESS_MESSAGE = (

async def _async_serving_models_init() -> OpenAIServingModels:
    mock_model_config = MagicMock(spec=ModelConfig)
+   mock_engine_client = MagicMock(spec=EngineClient)
    # Set the max_model_len attribute to avoid missing attribute
    mock_model_config.max_model_len = 2048

-   serving_models = OpenAIServingModels(base_model_paths=BASE_MODEL_PATHS,
+   serving_models = OpenAIServingModels(engine_client=mock_engine_client,
+                                        base_model_paths=BASE_MODEL_PATHS,
                                         model_config=mock_model_config,
                                         lora_modules=None,
                                         prompt_adapters=None)
+   await serving_models.init_static_loras()

    return serving_models

@@ -113,5 +117,5 @@ async def test_unload_lora_adapter_not_found():
    request = UnloadLoraAdapterRequest(lora_name="nonexistent_adapter")
    response = await serving_models.unload_lora_adapter(request)
    assert isinstance(response, ErrorResponse)
-   assert response.type == "InvalidUserInput"
-   assert response.code == HTTPStatus.BAD_REQUEST
+   assert response.type == "NotFoundError"
+   assert response.code == HTTPStatus.NOT_FOUND
@@ -1,6 +1,3 @@
-import json
-import os

import openai
import pytest

@@ -10,16 +7,7 @@ MODEL_NAME = "meta-llama/Llama-3.2-1B"

@pytest.mark.asyncio
-async def test_shutdown_on_engine_failure(tmp_path):
-    # Use a bad adapter to crash the engine
-    # (This test will fail when that bug is fixed)
-    adapter_path = tmp_path / "bad_adapter"
-    os.mkdir(adapter_path)
-    with open(adapter_path / "adapter_model_config.json", "w") as f:
-        json.dump({"not": "real"}, f)
-    with open(adapter_path / "adapter_model.safetensors", "wb") as f:
-        f.write(b"this is fake")
-
+async def test_shutdown_on_engine_failure():
    # dtype, max-len etc set so that this can run in CI
    args = [
        "--dtype",

@@ -29,9 +17,6 @@ async def test_shutdown_on_engine_failure(tmp_path):
        "--enforce-eager",
        "--max-num-seqs",
        "128",
-       "--enable-lora",
-       "--lora-modules",
-       f"bad-adapter={tmp_path / 'bad_adapter'}",
    ]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:

@@ -39,9 +24,13 @@ async def test_shutdown_on_engine_failure(tmp_path):

        with pytest.raises(
                (openai.APIConnectionError, openai.InternalServerError)):
-           # This crashes the engine
-           await client.completions.create(model="bad-adapter",
-                                           prompt="Hello, my name is")
+           # Asking for lots of prompt logprobs will currently crash the
+           # engine. This may change in the future when that bug is fixed
+           prompt = "Hello " * 4000
+           await client.completions.create(
+               model=MODEL_NAME,
+               prompt=prompt,
+               extra_body={"prompt_logprobs": 10})

        # Now the server should shut down
        return_code = remote_server.proc.wait(timeout=8)
@@ -1257,6 +1257,10 @@ class AsyncLLMEngine(EngineClient):
        else:
            self.engine.model_executor._run_workers("stop_profile")

+   async def add_lora(self, lora_request: LoRARequest) -> None:
+       """Load a new LoRA adapter into the engine for future requests."""
+       self.engine.add_lora(lora_request)
+

# TODO(v1): Remove this class proxy when V1 goes default.
if envs.VLLM_USE_V1:
@@ -1,4 +1,5 @@
-from dataclasses import dataclass
+import uuid
+from dataclasses import dataclass, field
from enum import Enum
from typing import List, Mapping, Optional, Union, overload

@@ -120,10 +121,23 @@ class RPCUProfileRequest(Enum):
    STOP_PROFILE = 2


-RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest,
-                      RPCUProfileRequest]
+@dataclass
+class RPCLoadAdapterRequest:
+    lora_request: LoRARequest
+    # Set the default value of request_id to a new UUID
+    request_id: str = field(default_factory=lambda: str(uuid.uuid4()))
+
+
+@dataclass
+class RPCAdapterLoadedResponse:
+    request_id: str
+
+
+RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest,
+                      RPCUProfileRequest, RPCLoadAdapterRequest]

-REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError]
+REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCAdapterLoadedResponse,
+                          RPCError]


def ENGINE_DEAD_ERROR(
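As a rough illustration (not part of the diff, adapter name and path are placeholders): every load request carries its own auto-generated request_id, which the client code in the hunks that follow uses as the key for a per-request response queue.

from vllm.engine.multiprocessing import RPCLoadAdapterRequest
from vllm.lora.request import LoRARequest

req = RPCLoadAdapterRequest(
    lora_request=LoRARequest(lora_name="my-adapter",       # placeholder
                             lora_int_id=1,
                             lora_path="/path/to/adapter"))  # placeholder
# A fresh UUID per request, used to route the RPCAdapterLoadedResponse (or an
# RPCError) back to whoever sent this request.
print(req.request_id)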
@@ -25,8 +25,10 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
                                         IPC_HEALTH_EXT, IPC_INPUT_EXT,
                                         IPC_OUTPUT_EXT, RPC_REQUEST_T,
                                         VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
-                                        RPCError, RPCProcessRequest,
-                                        RPCStartupRequest, RPCStartupResponse,
+                                        RPCAdapterLoadedResponse, RPCError,
+                                        RPCLoadAdapterRequest,
+                                        RPCProcessRequest, RPCStartupRequest,
+                                        RPCStartupResponse,
                                         RPCUProfileRequest)
from vllm.engine.protocol import EngineClient
# yapf: enable

@@ -240,17 +242,22 @@ class MQLLMEngineClient(EngineClient):
                        queue = self.output_queues.get(request_id)
                        if queue is not None:
                            queue.put_nowait(exception)
+               # Put each output into the appropriate queue.
+               elif isinstance(request_outputs, RPCAdapterLoadedResponse):
+                   self._add_output(request_outputs)
                else:
-                   # Put each output into the appropriate steam.
                    for request_output in request_outputs:
-                       queue = self.output_queues.get(
-                           request_output.request_id)
-                       if queue is not None:
-                           queue.put_nowait(request_output)
+                       self._add_output(request_output)

        except asyncio.CancelledError:
            logger.debug("Shutting down MQLLMEngineClient output handler.")

+   def _add_output(self, request_output: Union[RequestOutput,
+                                                RPCAdapterLoadedResponse]):
+       queue = self.output_queues.get(request_output.request_id)
+       if queue is not None:
+           queue.put_nowait(request_output)
+
    async def setup(self):
        """Setup the client before it starts sending server requests."""

@@ -659,3 +666,24 @@ class MQLLMEngineClient(EngineClient):

        await self._send_one_way_rpc_request(
            request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket)
+
+   async def add_lora(self, lora_request: LoRARequest) -> None:
+       """Load a new LoRA adapter into the engine for future requests."""
+       # Uses the same I/O as generate requests
+       request = RPCLoadAdapterRequest(lora_request)
+
+       # Create an output queue for this request.
+       queue: asyncio.Queue[Union[None, BaseException]] = asyncio.Queue()
+       self.output_queues[request.request_id] = queue
+
+       # Send the request
+       request_bytes = pickle.dumps(request)
+       await self.input_socket.send_multipart((request_bytes, ), copy=False)
+
+       # Wait for the response
+       request_output = await queue.get()
+       self.output_queues.pop(request.request_id)
+
+       # Raise on error, otherwise happily return None
+       if isinstance(request_output, BaseException):
+           raise request_output
@@ -14,8 +14,10 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
                                         IPC_HEALTH_EXT, IPC_INPUT_EXT,
                                         IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
                                         VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
-                                        RPCError, RPCProcessRequest,
-                                        RPCStartupRequest, RPCStartupResponse,
+                                        RPCAdapterLoadedResponse, RPCError,
+                                        RPCLoadAdapterRequest,
+                                        RPCProcessRequest, RPCStartupRequest,
+                                        RPCStartupResponse,
                                         RPCUProfileRequest)
# yapf: enable
from vllm.executor.gpu_executor import GPUExecutor

@@ -234,6 +236,8 @@ class MQLLMEngine:
                    self.start_profile()
                else:
                    self.stop_profile()
+           elif isinstance(request, RPCLoadAdapterRequest):
+               self._handle_load_adapter_request(request)
            else:
                raise ValueError("Unknown RPCRequest Type: "
                                 f"{type(request)}")

@@ -284,6 +288,19 @@ class MQLLMEngine:
        if self.log_requests:
            logger.info("Aborted request %s.", request.request_id)

+   def _handle_load_adapter_request(self, request: RPCLoadAdapterRequest):
+       try:
+           self.engine.add_lora(request.lora_request)
+       except BaseException as e:
+           # Send back an error if the adapter fails to load
+           rpc_err = RPCError(request_id=request.request_id,
+                              is_engine_errored=False,
+                              exception=e)
+           self._send_outputs(rpc_err)
+       # Otherwise, send back the successful load message
+       self._send_outputs(
+           RPCAdapterLoadedResponse(request_id=request.request_id))
+
    def _health_check(self):
        # Send unhealthy if engine has already errored
        if self._errored_with is not None:

@@ -296,7 +313,11 @@ class MQLLMEngine:
            self._send_unhealthy(e)

    def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
-       """Send List of RequestOutput to RPCClient."""
+       """Send outputs back to the engine client. These can be:
+       - Exceptions
+       - A list of generation outputs
+       - A response from loading a lora adapter
+       """
        if outputs:
            try:
                from ray.exceptions import RayTaskError
@@ -270,3 +270,8 @@ class EngineClient(ABC):
    async def stop_profile(self) -> None:
        """Start profiling the engine"""
        ...
+
+   @abstractmethod
+   async def add_lora(self, lora_request: LoRARequest) -> None:
+       """Load a new LoRA adapter into the engine for future requests."""
+       ...
@@ -662,7 +662,7 @@ def build_app(args: Namespace) -> FastAPI:
    return app


-def init_app_state(
+async def init_app_state(
    engine_client: EngineClient,
    model_config: ModelConfig,
    state: State,

@@ -690,12 +690,13 @@ def init_app_state(
        logger.info("Using supplied chat template:\n%s", resolved_chat_template)

    state.openai_serving_models = OpenAIServingModels(
+       engine_client=engine_client,
        model_config=model_config,
        base_model_paths=base_model_paths,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
    )
-   # TODO: The chat template is now broken for lora adapters :(
+   await state.openai_serving_models.init_static_loras()
    state.openai_serving_chat = OpenAIServingChat(
        engine_client,
        model_config,

@@ -794,7 +795,7 @@ async def run_server(args, **uvicorn_kwargs) -> None:
        app = build_app(args)

        model_config = await engine_client.get_model_config()
-       init_app_state(engine_client, model_config, app.state, args)
+       await init_app_state(engine_client, model_config, app.state, args)

        shutdown_task = await serve_http(
            app,
@@ -215,6 +215,7 @@ async def main(args):

    # Create the openai serving objects.
    openai_serving_models = OpenAIServingModels(
+       engine_client=engine,
        model_config=model_config,
        base_model_paths=base_model_paths,
        lora_modules=None,
@@ -5,15 +5,19 @@ from http import HTTPStatus
from typing import List, Optional, Union

from vllm.config import ModelConfig
+from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.protocol import (ErrorResponse,
                                              LoadLoraAdapterRequest,
                                              ModelCard, ModelList,
                                              ModelPermission,
                                              UnloadLoraAdapterRequest)
+from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.utils import AtomicCounter

+logger = init_logger(__name__)
+

@dataclass
class BaseModelPath:

@@ -45,6 +49,7 @@ class OpenAIServingModels:

    def __init__(
        self,
+       engine_client: EngineClient,
        model_config: ModelConfig,
        base_model_paths: List[BaseModelPath],
        *,

@@ -55,20 +60,11 @@ class OpenAIServingModels:

        self.base_model_paths = base_model_paths
        self.max_model_len = model_config.max_model_len
+       self.engine_client = engine_client
+
+       self.static_lora_modules = lora_modules
+       self.lora_requests: List[LoRARequest] = []
        self.lora_id_counter = AtomicCounter(0)
-       self.lora_requests = []
-       if lora_modules is not None:
-           self.lora_requests = [
-               LoRARequest(lora_name=lora.name,
-                           lora_int_id=i,
-                           lora_path=lora.path,
-                           base_model_name=lora.base_model_name
-                           if lora.base_model_name
-                           and self.is_base_model(lora.base_model_name) else
-                           self.base_model_paths[0].name)
-               for i, lora in enumerate(lora_modules, start=1)
-           ]

        self.prompt_adapter_requests = []
        if prompt_adapters is not None:

@@ -84,6 +80,19 @@ class OpenAIServingModels:
                    prompt_adapter_local_path=prompt_adapter.local_path,
                    prompt_adapter_num_virtual_tokens=num_virtual_tokens))

+   async def init_static_loras(self):
+       """Loads all static LoRA modules.
+       Raises if any fail to load"""
+       if self.static_lora_modules is None:
+           return
+       for lora in self.static_lora_modules:
+           load_request = LoadLoraAdapterRequest(lora_path=lora.path,
+                                                 lora_name=lora.name)
+           load_result = await self.load_lora_adapter(
+               request=load_request, base_model_name=lora.base_model_name)
+           if isinstance(load_result, ErrorResponse):
+               raise ValueError(load_result.message)
+
    def is_base_model(self, model_name):
        return any(model.name == model_name for model in self.base_model_paths)

@@ -129,17 +138,47 @@ class OpenAIServingModels:

    async def load_lora_adapter(
            self,
-           request: LoadLoraAdapterRequest) -> Union[ErrorResponse, str]:
+           request: LoadLoraAdapterRequest,
+           base_model_name: Optional[str] = None
+   ) -> Union[ErrorResponse, str]:
        error_check_ret = await self._check_load_lora_adapter_request(request)
        if error_check_ret is not None:
            return error_check_ret

        lora_name, lora_path = request.lora_name, request.lora_path
        unique_id = self.lora_id_counter.inc(1)
-       self.lora_requests.append(
-           LoRARequest(lora_name=lora_name,
-                       lora_int_id=unique_id,
-                       lora_path=lora_path))
+       lora_request = LoRARequest(lora_name=lora_name,
+                                  lora_int_id=unique_id,
+                                  lora_path=lora_path)
+       if base_model_name is not None and self.is_base_model(base_model_name):
+           lora_request.base_model_name = base_model_name
+
+       # Validate that the adapter can be loaded into the engine
+       # This will also pre-load it for incoming requests
+       try:
+           await self.engine_client.add_lora(lora_request)
+       except ValueError as e:
+           # Adapter not found or lora configuration errors
+           if "No adapter found" in str(e):
+               return create_error_response(message=str(e),
+                                            err_type="NotFoundError",
+                                            status_code=HTTPStatus.NOT_FOUND)
+           else:
+               return create_error_response(
+                   message=str(e),
+                   err_type="BadRequestError",
+                   status_code=HTTPStatus.BAD_REQUEST)
+       except BaseException as e:
+           # Some other unexpected problem loading the adapter, e.g. malformed
+           # input files.
+           # More detailed error messages for the user would be nicer here
+           return create_error_response(message=str(e),
+                                        err_type="BadRequestError",
+                                        status_code=HTTPStatus.BAD_REQUEST)
+
+       self.lora_requests.append(lora_request)
+       logger.info("Loaded new LoRA adapter: name '%s', path '%s'", lora_name,
+                   lora_path)
        return f"Success: LoRA adapter '{lora_name}' added successfully."

    async def unload_lora_adapter(

@@ -155,6 +194,7 @@ class OpenAIServingModels:
            lora_request for lora_request in self.lora_requests
            if lora_request.lora_name != lora_name
        ]
+       logger.info("Removed LoRA adapter: name '%s'", lora_name)
        return f"Success: LoRA adapter '{lora_name}' removed successfully."

    async def _check_load_lora_adapter_request(

@@ -195,8 +235,8 @@ class OpenAIServingModels:
            return create_error_response(
                message=
                f"The lora adapter '{request.lora_name}' cannot be found.",
-               err_type="InvalidUserInput",
-               status_code=HTTPStatus.BAD_REQUEST)
+               err_type="NotFoundError",
+               status_code=HTTPStatus.NOT_FOUND)

        return None
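For completeness, a hedged sketch of the matching unload call (not part of this diff; it assumes the same openai.AsyncOpenAI client as in the tests above and that the companion unload_lora_adapter route is enabled alongside load_lora_adapter; the adapter name is a placeholder):

import openai


async def unload_adapter(client: openai.AsyncOpenAI, name: str) -> str:
    # Mirrors the load_lora_adapter calls in the tests above, but for the
    # assumed /v1/unload_lora_adapter route.
    return await client.post("unload_lora_adapter",
                             cast_to=str,
                             body={"lora_name": name})

With the error-type change above, asking to unload a name that was never loaded now surfaces as openai.NotFoundError (HTTP 404) rather than a generic 400 "InvalidUserInput" error, matching the updated unload test earlier in this diff.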
@@ -115,6 +115,14 @@ class WorkerLoRAManager(AbstractWorkerManager):
                embedding_padding_modules=self.embedding_padding_modules,
                weights_mapper=hf_to_vllm_mapper)

+       except FileNotFoundError as e:
+           # FileNotFoundError should be raised if both
+           #  - No adapter found to download from huggingface (or in
+           #    offline mode)
+           #  - No local adapter files found at `lora_request.lora_path`
+           raise ValueError(
+               f"Loading lora {lora_request.lora_name} failed: No adapter "
+               f"found for {lora_path}") from e
        except Exception as e:
            raise RuntimeError(f"Loading lora {lora_path} failed") from e
        if lora.rank > self.lora_config.max_lora_rank:

@@ -209,12 +217,19 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):

    def add_adapter(self, lora_request: LoRARequest) -> bool:
        if lora_request.lora_int_id not in self.list_adapters():
-           # Remove before we load the new lora to save memory
+           # Load the new adapter first to ensure it is actually valid, before
+           # evicting any existing adapters.
+           # This may cause the # of loaded lora adapters to very temporarily
+           # exceed `--max-cpu-loras`.
+           lora = self._load_adapter(lora_request)
+
+           # Loading succeeded, now check if we will exceed cache capacity and
+           # evict the oldest adapter if so
            if len(self._adapter_manager) + 1 > self._adapter_manager.capacity:
                assert isinstance(self._adapter_manager,
                                  LRUCacheLoRAModelManager)
                self._adapter_manager.remove_oldest_adapter()
-           lora = self._load_adapter(lora_request)
+           # Then add the new adapter to the cache
            loaded = self._adapter_manager.add_adapter(lora)
        else:
            # If the lora is already loaded, just touch it to
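The comment block above captures the key ordering change: validate and load the new adapter before evicting anything from the CPU LRU cache. A toy sketch of why that ordering matters (plain Python, not vLLM code; all names are made up):

from collections import OrderedDict
from typing import Callable


def add_with_validation(cache: "OrderedDict[str, object]", capacity: int,
                        name: str, load_fn: Callable[[str], object]) -> object:
    # Load (and thereby validate) first: if load_fn raises, the cache is left
    # exactly as it was, so a broken adapter can never evict a healthy one.
    adapter = load_fn(name)

    # For a moment one more adapter than `capacity` is resident in memory,
    # mirroring the "very temporarily exceed --max-cpu-loras" note above.
    while len(cache) >= capacity:
        cache.popitem(last=False)  # evict the least recently used entry

    cache[name] = adapter
    return adapter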
@@ -339,3 +339,7 @@ class AsyncLLM(EngineClient):
    @property
    def dead_error(self) -> BaseException:
        return Exception()  # TODO: implement
+
+   async def add_lora(self, lora_request: LoRARequest) -> None:
+       """Load a new LoRA adapter into the engine for future requests."""
+       raise NotImplementedError("LoRA not yet supported in V1")