[Misc][LoRA] Improve the readability of LoRA error messages (#12102)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Jee Jee Li 2025-01-17 19:32:28 +08:00 committed by GitHub
parent 69d765f5a5
commit 07934cc237
10 changed files with 243 additions and 114 deletions

View File

@@ -17,6 +17,33 @@ MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"

+BADREQUEST_CASES = [
+    (
+        "test_rank",
+        {
+            "r": 1024
+        },
+        "is greater than max_lora_rank",
+    ),
+    (
+        "test_bias",
+        {
+            "bias": "all"
+        },
+        "Adapter bias cannot be used without bias_enabled",
+    ),
+    ("test_dora", {
+        "use_dora": True
+    }, "does not yet support DoRA"),
+    (
+        "test_modules_to_save",
+        {
+            "modules_to_save": ["lm_head"]
+        },
+        "only supports modules_to_save being None",
+    ),
+]
+

@pytest.fixture(scope="module")
def zephyr_lora_files():
@@ -138,32 +165,36 @@ async def test_dynamic_lora_invalid_files(client: openai.AsyncOpenAI,

@pytest.mark.asyncio
-async def test_dynamic_lora_invalid_lora_rank(client: openai.AsyncOpenAI,
-                                              tmp_path, zephyr_lora_files):
-    invalid_rank = tmp_path / "invalid_rank"
-    # Copy adapter from zephyr_lora_files to invalid_rank
-    shutil.copytree(zephyr_lora_files, invalid_rank)
-    with open(invalid_rank / "adapter_config.json") as f:
+@pytest.mark.parametrize("test_name,config_change,expected_error",
+                         BADREQUEST_CASES)
+async def test_dynamic_lora_badrequests(client: openai.AsyncOpenAI, tmp_path,
+                                        zephyr_lora_files, test_name: str,
+                                        config_change: dict,
+                                        expected_error: str):
+    # Create test directory
+    test_dir = tmp_path / test_name
+
+    # Copy adapter files
+    shutil.copytree(zephyr_lora_files, test_dir)
+
+    # Load and modify configuration
+    config_path = test_dir / "adapter_config.json"
+    with open(config_path) as f:
        adapter_config = json.load(f)
-    print(adapter_config)
-    # assert False
-    # Change rank to invalid value
-    adapter_config["r"] = 1024
-    with open(invalid_rank / "adapter_config.json", "w") as f:
+    # Apply configuration changes
+    adapter_config.update(config_change)
+
+    # Save modified configuration
+    with open(config_path, "w") as f:
        json.dump(adapter_config, f)

-    with pytest.raises(openai.BadRequestError,
-                       match="is greater than max_lora_rank"):
+    # Test loading the adapter
+    with pytest.raises(openai.BadRequestError, match=expected_error):
        await client.post("load_lora_adapter",
                          cast_to=str,
                          body={
-                              "lora_name": "invalid-json",
-                              "lora_path": str(invalid_rank)
+                              "lora_name": test_name,
+                              "lora_path": str(test_dir)
                          })

View File

@@ -3,6 +3,7 @@ from typing import List
import pytest

from vllm.lora.models import LoRAModel
+from vllm.lora.peft_helper import PEFTHelper
from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
from vllm.model_executor.models.utils import WeightsMapper

@@ -30,11 +31,14 @@ def test_load_checkpoints(
        else:
            expected_lora_modules.append(module)
    if lora_name == "baichuan7B":
+        peft_helper = PEFTHelper.from_local_dir(baichuan_lora_files,
+                                                max_position_embeddings=4096)
        # For the baichuan7B model, load it's LoRA,
        # and the test should pass.
        LoRAModel.from_local_checkpoint(
            baichuan_lora_files,
            expected_lora_modules,
+            peft_helper=peft_helper,
            lora_model_id=1,
            device="cpu",
            embedding_modules=embedding_modules,

@@ -43,9 +47,12 @@ def test_load_checkpoints(
        # Test that the target_modules contain prefix
        # such as "model.layers.0.self_atten.W_pack", and
        # the test should pass.
+        peft_helper = PEFTHelper.from_local_dir(baichuan_zero_lora_files,
+                                                max_position_embeddings=4096)
        LoRAModel.from_local_checkpoint(
            baichuan_zero_lora_files,
            expected_lora_modules,
+            peft_helper=peft_helper,
            lora_model_id=1,
            device="cpu",
            embedding_modules=embedding_modules,

@@ -53,9 +60,12 @@
    elif lora_name == "baichuan7B-zero-regex":
        # Test that the `target_modules` in the form of regular expressions,
        # such as `model\\..*(W_pack|o_proj)`, and the test should pass.
+        peft_helper = PEFTHelper.from_local_dir(baichuan_regex_lora_files,
+                                                max_position_embeddings=4096)
        LoRAModel.from_local_checkpoint(
            baichuan_regex_lora_files,
            expected_lora_modules,
+            peft_helper=peft_helper,
            lora_model_id=1,
            device="cpu",
            embedding_modules=embedding_modules,

@@ -64,10 +74,13 @@
        # For the baichuan7B model, load chatglm3-6b's LoRA,
        # and the test should raise the following error.
        expected_error = "Please verify that the loaded LoRA module is correct"  # noqa: E501
+        peft_helper = PEFTHelper.from_local_dir(chatglm3_lora_files,
+                                                max_position_embeddings=4096)
        with pytest.raises(ValueError, match=expected_error):
            LoRAModel.from_local_checkpoint(
                chatglm3_lora_files,
                expected_lora_modules,
+                peft_helper=peft_helper,
                lora_model_id=1,
                device="cpu",
                embedding_modules=embedding_modules,

@@ -94,9 +107,12 @@ def test_lora_weights_mapping(baichuan_lora_files):
            ".layers.": ".baichuan_layers.",
        },
    )
+    peft_helper = PEFTHelper.from_local_dir(baichuan_lora_files,
+                                            max_position_embeddings=4096)
    lora_model = LoRAModel.from_local_checkpoint(
        baichuan_lora_files,
        expected_lora_modules,
+        peft_helper=peft_helper,
        lora_model_id=1,
        device="cpu",
        embedding_modules=embedding_modules,

View File

@@ -3,6 +3,7 @@ from typing import List
import pytest

from vllm.lora.models import LoRAModel
+from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.utils import get_adapter_absolute_path
from vllm.model_executor.models.llama import LlamaForCausalLM

@@ -27,9 +28,11 @@ def test_load_checkpoints_from_huggingface(lora_fixture_name, request):
    lora_path = get_adapter_absolute_path(lora_name)

    # lora loading should work for either absolute path and hugggingface id.
+    peft_helper = PEFTHelper.from_local_dir(lora_path, 4096)
    lora_model = LoRAModel.from_local_checkpoint(
        lora_path,
        expected_lora_modules,
+        peft_helper=peft_helper,
        lora_model_id=1,
        device="cpu",
        embedding_modules=embedding_modules,

View File

@@ -1,5 +1,3 @@
-import json
-import math
import os
from typing import Dict, List

@@ -34,56 +32,6 @@ DEVICES = ([
] if current_platform.is_cuda_alike() else ["cpu"])

-def test_peft_helper(sql_lora_files):
-    lora_config_path = os.path.join(sql_lora_files, "adapter_config.json")
-    with open(lora_config_path) as f:
-        config = json.load(f)
-    peft_helper = PEFTHelper.from_dict(config)
-    assert peft_helper.r == 8
-    assert peft_helper.lora_alpha == 16
-    assert peft_helper.target_modules == [
-        "q_proj",
-        "v_proj",
-        "k_proj",
-        "o_proj",
-        "gate_proj",
-        "up_proj",
-        "down_proj",
-        "embed_tokens",
-        "lm_head",
-    ]
-    scaling = peft_helper.lora_alpha / peft_helper.r
-    assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3
-
-    # test RSLoRA
-    config = dict(r=8,
-                  lora_alpha=16,
-                  target_modules=["gate_proj"],
-                  use_rslora=True)
-    peft_helper = PEFTHelper.from_dict(config)
-
-    scaling = peft_helper.lora_alpha / math.sqrt(peft_helper.r)
-    assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3
-
-    expected_error = "vLLM only supports modules_to_save being None."
-    with pytest.raises(ValueError, match=expected_error):
-        config = dict(
-            r=8,
-            lora_alpha=16,
-            target_modules=["gate_proj"],
-            modules_to_save=["lm_head"],
-        )
-        PEFTHelper.from_dict(config)
-
-    expected_error = "vLLM does not yet support DoRA."
-    with pytest.raises(ValueError, match=expected_error):
-        config = dict(r=8,
-                      lora_alpha=16,
-                      target_modules=["gate_proj"],
-                      use_dora=True)
-        PEFTHelper.from_dict(config)

@pytest.mark.parametrize("device", DEVICES)
def test_from_lora_tensors(sql_lora_files, device):
    tensors = load_file(

@@ -91,11 +39,8 @@ def test_from_lora_tensors(sql_lora_files, device):
    new_embeddings = load_file(
        os.path.join(sql_lora_files, "new_embeddings.safetensors"))
-    lora_config_path = os.path.join(sql_lora_files, "adapter_config.json")
-    with open(lora_config_path) as f:
-        config = json.load(f)
-    peft_helper = PEFTHelper.from_dict(config)
+    peft_helper = PEFTHelper.from_local_dir(sql_lora_files,
+                                            max_position_embeddings=4096)
    lora_model = LoRAModel.from_lora_tensors(
        1,
        tensors,

View File

@@ -0,0 +1,109 @@
import json
import math
import shutil

import pytest

from vllm.config import LoRAConfig
from vllm.lora.peft_helper import PEFTHelper

ERROR_CASES = [
    (
        "test_rank",
        {
            "r": 1024
        },
        "is greater than max_lora_rank",
    ),
    (
        "test_bias",
        {
            "bias": "all"
        },
        "Adapter bias cannot be used without bias_enabled",
    ),
    ("test_dora", {
        "use_dora": True
    }, "does not yet support DoRA"),
    (
        "test_modules_to_save",
        {
            "modules_to_save": ["lm_head"]
        },
        "only supports modules_to_save being None",
    ),
]


def test_peft_helper_pass(long_context_lora_files_16k_1, tmp_path):
    peft_helper = PEFTHelper.from_local_dir(long_context_lora_files_16k_1,
                                            max_position_embeddings=4096)
    lora_config = LoRAConfig(max_lora_rank=16, max_cpu_loras=3, max_loras=2)
    peft_helper.validate_legal(lora_config)
    assert peft_helper.r == 8
    assert peft_helper.lora_alpha == 16
    assert peft_helper.target_modules == [
        "q_proj",
        "v_proj",
        "k_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "embed_tokens",
        "lm_head",
    ]
    assert peft_helper.context_length == 16384
    assert peft_helper.vllm_max_position_embeddings == 4096
    assert peft_helper.vllm_long_context_scaling_factor == float(
        math.ceil(peft_helper.context_length /
                  peft_helper.vllm_max_position_embeddings))
    # test RSLoRA
    rslora_config = dict(use_rslora=True)
    test_dir = tmp_path / "test_rslora"
    shutil.copytree(long_context_lora_files_16k_1, test_dir)

    # Load and modify configuration
    config_path = test_dir / "adapter_config.json"
    with open(config_path) as f:
        adapter_config = json.load(f)
    # Apply configuration changes
    adapter_config.update(rslora_config)

    # Save modified configuration
    with open(config_path, "w") as f:
        json.dump(adapter_config, f)

    peft_helper = PEFTHelper.from_local_dir(test_dir,
                                            max_position_embeddings=4096)
    peft_helper.validate_legal(lora_config)
    scaling = peft_helper.lora_alpha / math.sqrt(peft_helper.r)
    assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3


@pytest.mark.parametrize("test_name,config_change,expected_error", ERROR_CASES)
def test_peft_helper_error(
    sql_lora_files,
    tmp_path,
    test_name: str,
    config_change: dict,
    expected_error: str,
):
    test_dir = tmp_path / test_name
    shutil.copytree(sql_lora_files, test_dir)

    # Load and modify configuration
    config_path = test_dir / "adapter_config.json"
    with open(config_path) as f:
        adapter_config = json.load(f)
    # Apply configuration changes
    adapter_config.update(config_change)

    # Save modified configuration
    with open(config_path, "w") as f:
        json.dump(adapter_config, f)
    lora_config = LoRAConfig(max_lora_rank=16, max_cpu_loras=3, max_loras=2)
    # Test loading the adapter
    with pytest.raises(ValueError, match=expected_error):
        PEFTHelper.from_local_dir(
            test_dir, max_position_embeddings=4096).validate_legal(lora_config)

View File

@@ -296,6 +296,7 @@ class MQLLMEngine:
                               is_engine_errored=False,
                               exception=e)
            self._send_outputs(rpc_err)
+            return
        # Otherwise, send back the successful load message
        self._send_outputs(
            RPCAdapterLoadedResponse(request_id=request.request_id))

View File

@@ -157,24 +157,16 @@ class OpenAIServingModels:
        # This will also pre-load it for incoming requests
        try:
            await self.engine_client.add_lora(lora_request)
-        except ValueError as e:
-            # Adapter not found or lora configuration errors
-            if "No adapter found" in str(e):
-                return create_error_response(message=str(e),
-                                             err_type="NotFoundError",
-                                             status_code=HTTPStatus.NOT_FOUND)
-            else:
-                return create_error_response(
-                    message=str(e),
-                    err_type="BadRequestError",
-                    status_code=HTTPStatus.BAD_REQUEST)
        except BaseException as e:
-            # Some other unexpected problem loading the adapter, e.g. malformed
-            # input files.
-            # More detailed error messages for the user would be nicer here
+            error_type = "BadRequestError"
+            status_code = HTTPStatus.BAD_REQUEST
+            if isinstance(e, ValueError) and "No adapter found" in str(e):
+                error_type = "NotFoundError"
+                status_code = HTTPStatus.NOT_FOUND
            return create_error_response(message=str(e),
-                                         err_type="BadRequestError",
-                                         status_code=HTTPStatus.BAD_REQUEST)
+                                         err_type=error_type,
+                                         status_code=status_code)

        self.lora_requests.append(lora_request)
        logger.info("Loaded new LoRA adapter: name '%s', path '%s'", lora_name,

View File

@@ -1,5 +1,4 @@
import copy
-import json
import math
import os
import re

@@ -180,8 +179,8 @@ class LoRAModel(AdapterModel):
            cls,
            lora_dir: str,
            expected_lora_modules: List[str],
+            peft_helper: PEFTHelper,
            *,
-            max_position_embeddings: Optional[int] = None,
            lora_model_id: Optional[int] = None,
            device: str = "cuda",
            dtype: Optional[torch.dtype] = None,

@@ -196,9 +195,7 @@
            lora_dir: The local path that has lora data.
            expected_lora_modules: Name of modules that are expected to be
                replaced by lora.
-            max_position_embeddings: Max position embedding length. Used to
-                scaling the largest context length. If None, the lora model's
-                context length is not scaled.
+            peft_helper: Loaded lora configuration information.
            lora_model_id: Lora model id. If not given, automatically set by
                a global counter.
            device: Device where the lora model is loaded.

@@ -207,18 +204,13 @@
        Returns:
            Loaded LoRA Model.
        """
-        lora_config_path = os.path.join(lora_dir, "adapter_config.json")
        lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
        lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
        new_embeddings_tensor_path = os.path.join(
            lora_dir, "new_embeddings.safetensors")
        new_embeddings_bin_file_path = os.path.join(lora_dir,
                                                    "new_embeddings.bin")
-        with open(lora_config_path) as f:
-            config = json.load(f)
-        config["vllm_max_position_embeddings"] = max_position_embeddings
-        peft_helper = PEFTHelper.from_dict(config)
        unexpected_modules: List[Union[list[str], str]]
        if os.path.isfile(lora_tensor_path):
            tensors: Dict[str, torch.Tensor] = {}
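
For orientation, here is a hedged sketch (not part of the diff) of how a caller drives LoRAModel.from_local_checkpoint after this change: the adapter config is parsed once via PEFTHelper.from_local_dir and handed in, replacing the old max_position_embeddings keyword. The adapter path and module list are illustrative placeholders, and model-specific keyword arguments such as embedding_modules (visible in the test updates above) are omitted for brevity.

# Sketch only: illustrative paths and values, mirroring the updated tests.
from vllm.lora.models import LoRAModel
from vllm.lora.peft_helper import PEFTHelper

lora_dir = "/path/to/lora_adapter"            # hypothetical adapter directory
expected_lora_modules = ["q_proj", "v_proj"]  # illustrative module names

# Parse adapter_config.json once, up front.
peft_helper = PEFTHelper.from_local_dir(lora_dir, max_position_embeddings=4096)

# The parsed config is passed to the loader; the old
# max_position_embeddings keyword argument is gone.
lora_model = LoRAModel.from_local_checkpoint(
    lora_dir,
    expected_lora_modules,
    peft_helper=peft_helper,
    lora_model_id=1,
    device="cpu",
)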

View File

@@ -1,9 +1,12 @@
# Adapted from: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/config.py

+import json
import math
+import os
from dataclasses import MISSING, dataclass, field, fields
-from typing import Literal, Optional, Union
+from typing import List, Literal, Optional, Union

+from vllm.config import LoRAConfig
from vllm.logger import init_logger

logger = init_logger(__name__)

@@ -11,6 +14,12 @@ logger = init_logger(__name__)

@dataclass
class PEFTHelper:
+    """
+    A helper class for PEFT configurations, specifically designed for LoRA.
+    This class handles configuration validation, compatibility checks for
+    various LoRA implementations.
+    """
+
    # Required fields
    r: int
    lora_alpha: int

@@ -29,20 +38,18 @@ class PEFTHelper:
    vllm_max_position_embeddings: Optional[int] = field(default=False)
    vllm_long_context_scaling_factor: Optional[float] = field(default=None)

-    def _validate_features(self):
+    def _validate_features(self) -> List[str]:
+        """
+        Check if there are any unsupported Lora features.
+        """
        error_msg = []
        if self.modules_to_save:
            error_msg.append("vLLM only supports modules_to_save being None.")
        if self.use_dora:
            error_msg.append("vLLM does not yet support DoRA.")
-        if error_msg:
-            raise ValueError(f"{', '.join(error_msg)}")
+        return error_msg

    def __post_init__(self):
-        self._validate_features()
        if self.use_rslora:
            logger.info_once("Loading LoRA weights trained with rsLoRA.")
            self.vllm_lora_scaling_factor = self.lora_alpha / math.sqrt(self.r)

@@ -78,3 +85,29 @@
            for k, v in config_dict.items() if k in class_fields
        }
        return cls(**filtered_dict)
+
+    @classmethod
+    def from_local_dir(cls, lora_path: str,
+                       max_position_embeddings: Optional[int]) -> "PEFTHelper":
+        lora_config_path = os.path.join(lora_path, "adapter_config.json")
+
+        with open(lora_config_path) as f:
+            config = json.load(f)
+        config["vllm_max_position_embeddings"] = max_position_embeddings
+        return cls.from_dict(config)
+
+    def validate_legal(self, lora_config: LoRAConfig) -> None:
+        """
+        Validates the LoRA configuration settings against application
+        constraints and requirements.
+        """
+        error_msg = self._validate_features()
+        if self.r > lora_config.max_lora_rank:
+            error_msg.append(
+                f"LoRA rank {self.r} is greater than max_lora_rank"
+                f" {lora_config.max_lora_rank}.")
+        if self.bias != "none" and not lora_config.bias_enabled:
+            error_msg.append(
+                "Adapter bias cannot be used without bias_enabled.")
+        if error_msg:
+            raise ValueError(f"{' '.join(error_msg)}")
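
Taken together, the two new methods give a single place to surface configuration problems before any weights are loaded. A hedged usage sketch follows; the adapter path is a placeholder and the LoRAConfig values mirror the new unit test above.

from vllm.config import LoRAConfig
from vllm.lora.peft_helper import PEFTHelper

# Server-side limits; values mirror the new test_peft_helper.py cases.
lora_config = LoRAConfig(max_lora_rank=16, max_cpu_loras=3, max_loras=2)

# Parse adapter_config.json and record the serving context length.
peft_helper = PEFTHelper.from_local_dir("/path/to/lora_adapter",  # placeholder
                                        max_position_embeddings=4096)

# Raises ValueError with an aggregated, readable message if, e.g., the adapter
# rank exceeds max_lora_rank, bias is used without bias_enabled, DoRA is
# requested, or modules_to_save is set.
peft_helper.validate_legal(lora_config)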

View File

@@ -12,6 +12,7 @@ from vllm.config import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.models import (LoRAModel, LoRAModelManager,
                              LRUCacheLoRAModelManager, create_lora_manager)
+from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path

@@ -95,6 +96,13 @@ class WorkerLoRAManager(AbstractWorkerManager):
            expected_lora_modules = list(set(expected_lora_modules))
            lora_path = get_adapter_absolute_path(lora_request.lora_path)
+            peft_helper = PEFTHelper.from_local_dir(
+                lora_path, self.max_position_embeddings)
+
+            # Validates the LoRA configuration against requirements before
+            # loading weights, throwing an exception if validation fails.
+            peft_helper.validate_legal(self.lora_config)
+
            # For some models like Qwen2VL, we need to use hf_to_vllm_mapper
            # to ensure correct loading of lora weights.
            hf_to_vllm_mapper = None

@@ -105,7 +113,7 @@ class WorkerLoRAManager(AbstractWorkerManager):
            lora = self._lora_model_cls.from_local_checkpoint(
                lora_path,
                expected_lora_modules,
-                max_position_embeddings=self.max_position_embeddings,
+                peft_helper=peft_helper,
                lora_model_id=lora_request.lora_int_id,
                device="cpu",
                dtype=self.lora_config.lora_dtype,

@@ -120,15 +128,14 @@ class WorkerLoRAManager(AbstractWorkerManager):
            # - No adapter found to download from huggingface (or in
            #   offline mode)
            # - No local adapter files found at `lora_request.lora_path`
+            # For NotFoundError
            raise ValueError(
                f"Loading lora {lora_request.lora_name} failed: No adapter "
                f"found for {lora_path}") from e
        except Exception as e:
-            raise RuntimeError(f"Loading lora {lora_path} failed") from e
-        if lora.rank > self.lora_config.max_lora_rank:
-            raise ValueError(
-                f"LoRA rank {lora.rank} is greater than max_lora_rank "
-                f"{self.lora_config.max_lora_rank}.")
+            # For BadRequestError
+            raise e
        if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
            raise ValueError(f"LoRA added vocab size {lora.extra_vocab_size} "
                             f"is greater than lora_extra_vocab_size "