[Misc][LoRA] Ensure Lora Adapter requests return adapter name (#11094)
Signed-off-by: Jiaxin Shan <seedjeffwan@gmail.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
parent 62de37a38e
commit 85362f028c
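
For context: with this change, the model field reported by the OpenAI-compatible endpoints echoes the LoRA adapter name whenever a request is served through a LoRA adapter, instead of always reporting the base model name. A minimal client-side sketch of the intended behavior follows; it is not part of the diff, and it assumes a locally running vLLM OpenAI server started with --enable-lora and a --lora-modules entry named "sql_adapter" (the adapter name, path, and URL are placeholders).

# Sketch only: assumes the server was launched along the lines of
#   vllm serve meta-llama/Llama-2-7b --enable-lora \
#       --lora-modules sql_adapter=/path/to/adapter
# Adapter name, path, and URL below are placeholders, not part of this commit.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Route the request to the adapter by using its registered name as the model.
chat = client.chat.completions.create(
    model="sql_adapter",
    messages=[{"role": "user", "content": "Hello"}],
)

# With this commit, the response should report the adapter name rather than
# the base model name.
assert chat.model == "sql_adapter"
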
@@ -9,6 +9,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
                                               LoadLoraAdapterRequest,
                                               UnloadLoraAdapterRequest)
 from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
+from vllm.lora.request import LoRARequest
 
 MODEL_NAME = "meta-llama/Llama-2-7b"
 BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
@@ -33,6 +34,16 @@ async def _async_serving_engine_init():
     return serving_engine
 
 
+@pytest.mark.asyncio
+async def test_serving_model_name():
+    serving_engine = await _async_serving_engine_init()
+    assert serving_engine._get_model_name(None) == MODEL_NAME
+    request = LoRARequest(lora_name="adapter",
+                          lora_path="/path/to/adapter2",
+                          lora_int_id=1)
+    assert serving_engine._get_model_name(request) == request.lora_name
+
+
 @pytest.mark.asyncio
 async def test_load_lora_adapter_success():
     serving_engine = await _async_serving_engine_init()
@@ -123,6 +123,8 @@ class OpenAIServingChat(OpenAIServing):
                 prompt_adapter_request,
             ) = self._maybe_get_adapters(request)
 
+            model_name = self._get_model_name(lora_request)
+
             tokenizer = await self.engine_client.get_tokenizer(lora_request)
 
             tool_parser = self.tool_parser
@@ -238,13 +240,13 @@ class OpenAIServingChat(OpenAIServing):
         # Streaming response
         if request.stream:
             return self.chat_completion_stream_generator(
-                request, result_generator, request_id, conversation, tokenizer,
-                request_metadata)
+                request, result_generator, request_id, model_name,
+                conversation, tokenizer, request_metadata)
 
         try:
             return await self.chat_completion_full_generator(
-                request, result_generator, request_id, conversation, tokenizer,
-                request_metadata)
+                request, result_generator, request_id, model_name,
+                conversation, tokenizer, request_metadata)
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
@@ -259,11 +261,11 @@ class OpenAIServingChat(OpenAIServing):
         request: ChatCompletionRequest,
         result_generator: AsyncIterator[RequestOutput],
         request_id: str,
+        model_name: str,
         conversation: List[ConversationMessage],
         tokenizer: AnyTokenizer,
         request_metadata: RequestResponseMetadata,
     ) -> AsyncGenerator[str, None]:
-        model_name = self.base_model_paths[0].name
         created_time = int(time.time())
         chunk_object_type: Final = "chat.completion.chunk"
         first_iteration = True
@@ -604,12 +606,12 @@ class OpenAIServingChat(OpenAIServing):
         request: ChatCompletionRequest,
         result_generator: AsyncIterator[RequestOutput],
         request_id: str,
+        model_name: str,
         conversation: List[ConversationMessage],
         tokenizer: AnyTokenizer,
         request_metadata: RequestResponseMetadata,
     ) -> Union[ErrorResponse, ChatCompletionResponse]:
 
-        model_name = self.base_model_paths[0].name
         created_time = int(time.time())
         final_res: Optional[RequestOutput] = None
 
@@ -85,7 +85,6 @@ class OpenAIServingCompletion(OpenAIServing):
             return self.create_error_response(
                 "suffix is not currently supported")
 
-        model_name = self.base_model_paths[0].name
         request_id = f"cmpl-{self._base_request_id(raw_request)}"
         created_time = int(time.time())
 
@@ -162,6 +161,7 @@ class OpenAIServingCompletion(OpenAIServing):
         result_generator = merge_async_iterators(
             *generators, is_cancelled=raw_request.is_disconnected)
 
+        model_name = self._get_model_name(lora_request)
         num_prompts = len(engine_prompts)
 
         # Similar to the OpenAI API, when n != best_of, we do not stream the
@@ -661,3 +661,16 @@ class OpenAIServing:
 
     def _is_model_supported(self, model_name):
         return any(model.name == model_name for model in self.base_model_paths)
+
+    def _get_model_name(self, lora: Optional[LoRARequest]):
+        """
+        Returns the appropriate model name depending on the availability
+        and support of the LoRA or base model.
+        Parameters:
+        - lora: LoRARequest that contains a base_model_name.
+        Returns:
+        - str: The name of the base model or the first available model path.
+        """
+        if lora is not None:
+            return lora.lora_name
+        return self.base_model_paths[0].name
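
To illustrate the resolution rule added in _get_model_name above, here is a self-contained sketch that mirrors it with simplified stand-in types (the Fake* dataclasses and the free function below are illustrative only, not vLLM's actual BaseModelPath / LoRARequest classes), so the behavior can be checked without a running engine:

# Standalone sketch of the name-resolution rule; the dataclasses are
# simplified stand-ins, not vLLM's real BaseModelPath / LoRARequest types.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FakeBaseModelPath:
    name: str
    model_path: str


@dataclass
class FakeLoRARequest:
    lora_name: str
    lora_int_id: int
    lora_path: str


def get_model_name(base_model_paths: List[FakeBaseModelPath],
                   lora: Optional[FakeLoRARequest]) -> str:
    # Same rule as OpenAIServing._get_model_name: prefer the adapter name,
    # otherwise fall back to the first registered base model path.
    if lora is not None:
        return lora.lora_name
    return base_model_paths[0].name


base_paths = [FakeBaseModelPath(name="meta-llama/Llama-2-7b",
                                model_path="meta-llama/Llama-2-7b")]
adapter = FakeLoRARequest(lora_name="adapter", lora_int_id=1,
                          lora_path="/path/to/adapter2")

assert get_model_name(base_paths, None) == "meta-llama/Llama-2-7b"
assert get_model_name(base_paths, adapter) == "adapter"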