[tech debt] Revisit lora request model checker (#20636)

Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>
This commit is contained in:
kourosh hakhamaneshi 2025-07-08 18:42:41 -07:00 committed by GitHub
parent 0b407479ef
commit baed180aa0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 61 additions and 58 deletions

View File

@@ -57,7 +57,8 @@ async def test_load_lora_adapter_success():
response = await serving_models.load_lora_adapter(request)
assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name='adapter')
assert len(serving_models.lora_requests) == 1
assert serving_models.lora_requests[0].lora_name == "adapter"
assert "adapter" in serving_models.lora_requests
assert serving_models.lora_requests["adapter"].lora_name == "adapter"
@pytest.mark.asyncio

View File

@@ -438,9 +438,7 @@ class OpenAIServing:
if self._is_model_supported(request.model):
return None
if request.model in [
lora.lora_name for lora in self.models.lora_requests
]:
if request.model in self.models.lora_requests:
return None
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and request.model and (
load_result := await self.models.resolve_lora(request.model)):
@@ -466,9 +464,8 @@ class OpenAIServing:
None, PromptAdapterRequest]]:
if self._is_model_supported(request.model):
return None, None
for lora in self.models.lora_requests:
if request.model == lora.lora_name:
return lora, None
if request.model in self.models.lora_requests:
return self.models.lora_requests[request.model], None
for prompt_adapter in self.models.prompt_adapter_requests:
if request.model == prompt_adapter.prompt_adapter_name:
return None, prompt_adapter

View File

@@ -65,12 +65,13 @@ class OpenAIServingModels:
super().__init__()
self.base_model_paths = base_model_paths
self.max_model_len = model_config.max_model_len
self.engine_client = engine_client
self.model_config = model_config
self.static_lora_modules = lora_modules
self.lora_requests: list[LoRARequest] = []
self.lora_requests: dict[str, LoRARequest] = {}
self.lora_id_counter = AtomicCounter(0)
self.lora_resolvers: list[LoRAResolver] = []
@@ -138,7 +139,7 @@ class OpenAIServingModels:
parent=lora.base_model_name if lora.base_model_name else
self.base_model_paths[0].name,
permission=[ModelPermission()])
for lora in self.lora_requests
for lora in self.lora_requests.values()
]
prompt_adapter_cards = [
ModelCard(id=prompt_adapter.prompt_adapter_name,
@@ -155,53 +156,60 @@ class OpenAIServingModels:
request: LoadLoRAAdapterRequest,
base_model_name: Optional[str] = None
) -> Union[ErrorResponse, str]:
error_check_ret = await self._check_load_lora_adapter_request(request)
if error_check_ret is not None:
return error_check_ret
lora_name = request.lora_name
lora_name, lora_path = request.lora_name, request.lora_path
unique_id = self.lora_id_counter.inc(1)
lora_request = LoRARequest(lora_name=lora_name,
lora_int_id=unique_id,
lora_path=lora_path)
if base_model_name is not None and self.is_base_model(base_model_name):
lora_request.base_model_name = base_model_name
# Ensure atomicity based on the lora name
async with self.lora_resolver_lock[lora_name]:
error_check_ret = await self._check_load_lora_adapter_request(
request)
if error_check_ret is not None:
return error_check_ret
# Validate that the adapter can be loaded into the engine
# This will also pre-load it for incoming requests
try:
await self.engine_client.add_lora(lora_request)
except BaseException as e:
error_type = "BadRequestError"
status_code = HTTPStatus.BAD_REQUEST
if "No adapter found" in str(e):
error_type = "NotFoundError"
status_code = HTTPStatus.NOT_FOUND
lora_path = request.lora_path
unique_id = self.lora_id_counter.inc(1)
lora_request = LoRARequest(lora_name=lora_name,
lora_int_id=unique_id,
lora_path=lora_path)
if base_model_name is not None and self.is_base_model(
base_model_name):
lora_request.base_model_name = base_model_name
return create_error_response(message=str(e),
err_type=error_type,
status_code=status_code)
# Validate that the adapter can be loaded into the engine
# This will also pre-load it for incoming requests
try:
await self.engine_client.add_lora(lora_request)
except Exception as e:
error_type = "BadRequestError"
status_code = HTTPStatus.BAD_REQUEST
if "No adapter found" in str(e):
error_type = "NotFoundError"
status_code = HTTPStatus.NOT_FOUND
self.lora_requests.append(lora_request)
logger.info("Loaded new LoRA adapter: name '%s', path '%s'", lora_name,
lora_path)
return f"Success: LoRA adapter '{lora_name}' added successfully."
return create_error_response(message=str(e),
err_type=error_type,
status_code=status_code)
self.lora_requests[lora_name] = lora_request
logger.info("Loaded new LoRA adapter: name '%s', path '%s'",
lora_name, lora_path)
return f"Success: LoRA adapter '{lora_name}' added successfully."
async def unload_lora_adapter(
self,
request: UnloadLoRAAdapterRequest) -> Union[ErrorResponse, str]:
error_check_ret = await self._check_unload_lora_adapter_request(request
)
if error_check_ret is not None:
return error_check_ret
lora_name = request.lora_name
self.lora_requests = [
lora_request for lora_request in self.lora_requests
if lora_request.lora_name != lora_name
]
logger.info("Removed LoRA adapter: name '%s'", lora_name)
return f"Success: LoRA adapter '{lora_name}' removed successfully."
# Ensure atomicity based on the lora name
async with self.lora_resolver_lock[lora_name]:
error_check_ret = await self._check_unload_lora_adapter_request(
request)
if error_check_ret is not None:
return error_check_ret
# Safe to delete now since we hold the lock
del self.lora_requests[lora_name]
logger.info("Removed LoRA adapter: name '%s'", lora_name)
return f"Success: LoRA adapter '{lora_name}' removed successfully."
async def _check_load_lora_adapter_request(
self, request: LoadLoRAAdapterRequest) -> Optional[ErrorResponse]:
@@ -213,8 +221,7 @@ class OpenAIServingModels:
status_code=HTTPStatus.BAD_REQUEST)
# Check if the lora adapter with the given name already exists
if any(lora_request.lora_name == request.lora_name
for lora_request in self.lora_requests):
if request.lora_name in self.lora_requests:
return create_error_response(
message=
f"The lora adapter '{request.lora_name}' has already been "
@@ -227,17 +234,16 @@ class OpenAIServingModels:
async def _check_unload_lora_adapter_request(
self,
request: UnloadLoRAAdapterRequest) -> Optional[ErrorResponse]:
# Check if either 'lora_name' or 'lora_int_id' is provided
if not request.lora_name and not request.lora_int_id:
# Check if 'lora_name' is not provided return an error
if not request.lora_name:
return create_error_response(
message=
"either 'lora_name' and 'lora_int_id' needs to be provided.",
"'lora_name' needs to be provided to unload a LoRA adapter.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST)
# Check if the lora adapter with the given name exists
if not any(lora_request.lora_name == request.lora_name
for lora_request in self.lora_requests):
if request.lora_name not in self.lora_requests:
return create_error_response(
message=
f"The lora adapter '{request.lora_name}' cannot be found.",
@@ -260,9 +266,8 @@ class OpenAIServingModels:
"""
async with self.lora_resolver_lock[lora_name]:
# First check if this LoRA is already loaded
for existing in self.lora_requests:
if existing.lora_name == lora_name:
return existing
if lora_name in self.lora_requests:
return self.lora_requests[lora_name]
base_model_name = self.model_config.model
unique_id = self.lora_id_counter.inc(1)
@@ -279,7 +284,7 @@ class OpenAIServingModels:
try:
await self.engine_client.add_lora(lora_request)
self.lora_requests.append(lora_request)
self.lora_requests[lora_name] = lora_request
logger.info(
"Resolved and loaded LoRA adapter '%s' using %s",
lora_name, resolver.__class__.__name__)