[Misc][LoRA] Ensure LoRA adapter requests return adapter name (#11094)

Signed-off-by: Jiaxin Shan <seedjeffwan@gmail.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Jiaxin Shan 2024-12-12 01:25:16 -08:00 committed by GitHub
parent 62de37a38e
commit 85362f028c
4 changed files with 33 additions and 7 deletions
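
In short: when a request targets a LoRA adapter, the OpenAI-compatible endpoints now echo the adapter's name in the response's `model` field instead of always reporting the base model. A minimal client-side sketch of the expected behavior (assumes a locally running vLLM OpenAI-compatible server started with `--enable-lora` and an adapter registered under the name `adapter`; names, path, and port are illustrative):

```python
from openai import OpenAI

# Hypothetical local server; adjust base_url/api_key for your deployment.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

resp = client.chat.completions.create(
    model="adapter",  # the registered LoRA adapter name, not the base model
    messages=[{"role": "user", "content": "Hello"}],
    max_tokens=8,
)

# After this change, the response reports the adapter that served the request.
print(resp.model)  # expected: "adapter"
```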

View File

@@ -9,6 +9,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
LoadLoraAdapterRequest,
UnloadLoraAdapterRequest)
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
+from vllm.lora.request import LoRARequest
MODEL_NAME = "meta-llama/Llama-2-7b"
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
@@ -33,6 +34,16 @@ async def _async_serving_engine_init():
return serving_engine
+@pytest.mark.asyncio
+async def test_serving_model_name():
+serving_engine = await _async_serving_engine_init()
+assert serving_engine._get_model_name(None) == MODEL_NAME
+request = LoRARequest(lora_name="adapter",
+lora_path="/path/to/adapter2",
+lora_int_id=1)
+assert serving_engine._get_model_name(request) == request.lora_name
@pytest.mark.asyncio
async def test_load_lora_adapter_success():
serving_engine = await _async_serving_engine_init()
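
For readers unfamiliar with `LoRARequest` (used by the new `test_serving_model_name` above): only a name, a positive integer id that should be unique per adapter, and a checkpoint path are needed for the name resolution to work. A small construction sketch with placeholder values:

```python
from vllm.lora.request import LoRARequest

# Placeholder path; in practice this points at a trained LoRA/PEFT checkpoint.
request = LoRARequest(lora_name="adapter",
                      lora_int_id=1,
                      lora_path="/path/to/adapter2")

# The serving layer now reports this name back to the client.
print(request.lora_name)  # "adapter"
```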

View File

@@ -123,6 +123,8 @@ class OpenAIServingChat(OpenAIServing):
prompt_adapter_request,
) = self._maybe_get_adapters(request)
+model_name = self._get_model_name(lora_request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
tool_parser = self.tool_parser
@@ -238,13 +240,13 @@ class OpenAIServingChat(OpenAIServing):
# Streaming response
if request.stream:
return self.chat_completion_stream_generator(
-request, result_generator, request_id, conversation, tokenizer,
-request_metadata)
+request, result_generator, request_id, model_name,
+conversation, tokenizer, request_metadata)
try:
return await self.chat_completion_full_generator(
-request, result_generator, request_id, conversation, tokenizer,
-request_metadata)
+request, result_generator, request_id, model_name,
+conversation, tokenizer, request_metadata)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
@@ -259,11 +261,11 @@ class OpenAIServingChat(OpenAIServing):
request: ChatCompletionRequest,
result_generator: AsyncIterator[RequestOutput],
request_id: str,
+model_name: str,
conversation: List[ConversationMessage],
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
) -> AsyncGenerator[str, None]:
-model_name = self.base_model_paths[0].name
created_time = int(time.time())
chunk_object_type: Final = "chat.completion.chunk"
first_iteration = True
@@ -604,12 +606,12 @@ class OpenAIServingChat(OpenAIServing):
request: ChatCompletionRequest,
result_generator: AsyncIterator[RequestOutput],
request_id: str,
+model_name: str,
conversation: List[ConversationMessage],
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
) -> Union[ErrorResponse, ChatCompletionResponse]:
-model_name = self.base_model_paths[0].name
created_time = int(time.time())
final_res: Optional[RequestOutput] = None
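
The reason `model_name` is now a parameter of both generators: every response object they build carries a `model` field, and it should reflect the adapter that served the request rather than `base_model_paths[0]`. A schematic sketch of that pattern (plain dataclasses, not the actual vLLM response types):

```python
import time
from dataclasses import dataclass
from typing import AsyncGenerator, List


@dataclass
class Chunk:
    id: str
    model: str      # with this change: the adapter name for LoRA requests
    created: int
    delta: str


async def stream_chunks(request_id: str, model_name: str,
                        pieces: List[str]) -> AsyncGenerator[Chunk, None]:
    created = int(time.time())
    for piece in pieces:
        # Each streamed chunk is stamped with the resolved model name,
        # so clients can tell which adapter produced the output.
        yield Chunk(id=request_id, model=model_name, created=created, delta=piece)
```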

View File

@@ -85,7 +85,6 @@ class OpenAIServingCompletion(OpenAIServing):
return self.create_error_response(
"suffix is not currently supported")
-model_name = self.base_model_paths[0].name
request_id = f"cmpl-{self._base_request_id(raw_request)}"
created_time = int(time.time())
@@ -162,6 +161,7 @@ class OpenAIServingCompletion(OpenAIServing):
result_generator = merge_async_iterators(
*generators, is_cancelled=raw_request.is_disconnected)
+model_name = self._get_model_name(lora_request)
num_prompts = len(engine_prompts)
# Similar to the OpenAI API, when n != best_of, we do not stream the
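
The same applies to the completions endpoint. A hedged client-side check (assumes a hypothetical `adapter` module registered at server start, e.g. via `--lora-modules adapter=/path/to/adapter2`):

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

resp = client.completions.create(model="adapter", prompt="Hello", max_tokens=8)
print(resp.model)  # expected: "adapter" rather than the base model name
```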

View File

@@ -661,3 +661,16 @@ class OpenAIServing:
def _is_model_supported(self, model_name):
return any(model.name == model_name for model in self.base_model_paths)
+def _get_model_name(self, lora: Optional[LoRARequest]):
+"""
+Returns the model name to report in API responses: the LoRA adapter
+name when a LoRA request is present, otherwise the base model name.
+Parameters:
+- lora: Optional LoRARequest carrying the adapter name (lora_name).
+Returns:
+- str: lora.lora_name if lora is not None, else the name of the first
+base model path.
+"""
+if lora is not None:
+return lora.lora_name
+return self.base_model_paths[0].name
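
Finally, for completeness: the adapters whose names this helper reports can also be registered at runtime through the dynamic LoRA endpoints exercised elsewhere in the test file (`LoadLoraAdapterRequest` / `UnloadLoraAdapterRequest`). A hedged sketch (assumes the server was started with `--enable-lora` and `VLLM_ALLOW_RUNTIME_LORA_UPDATING=True`; names and paths are placeholders):

```python
import requests

BASE = "http://localhost:8000"

# Register an adapter at runtime; the name used here is what responses
# will report after this change.
requests.post(f"{BASE}/v1/load_lora_adapter",
              json={"lora_name": "adapter", "lora_path": "/path/to/adapter2"})

# ...and it can be removed again when no longer needed.
requests.post(f"{BASE}/v1/unload_lora_adapter",
              json={"lora_name": "adapter"})
```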