[HTTP Server] Make model param optional in request (#13568)

Keyun Tong 2025-02-21 21:55:50 -08:00 committed by GitHub
parent 8c0dd3d4df
commit 0ffdf8ce0c
9 changed files with 61 additions and 18 deletions
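Summary: the OpenAI-compatible request schemas (chat, completions, embeddings, pooling, score, rerank, tokenize, detokenize, transcription) no longer require the "model" field. When it is omitted, the server resolves a name itself: the active LoRA adapter name if one applies, otherwise the first configured base model. A minimal sketch of a request that relies on this fallback, assuming a vLLM OpenAI-compatible server is already running on localhost:8000 (the URL and port are illustrative, not part of this commit):

import requests

# Chat completion request with no "model" key; after this change the
# server falls back to its default served model instead of rejecting
# the request with a validation error.
url = "http://localhost:8000/v1/chat/completions"
payload = {
    "messages": [{"role": "user", "content": "what is 1+1?"}],
    "max_tokens": 5,
}

response = requests.post(url, json=payload)
response.raise_for_status()
print(response.json()["choices"][0]["message"]["content"])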


@@ -9,6 +9,7 @@ import jsonschema
 import openai  # use the official client for correctness check
 import pytest
 import pytest_asyncio
+import requests
 import torch
 from openai import BadRequestError
@@ -996,3 +997,34 @@ async def test_long_seed(client: openai.AsyncOpenAI):
         assert ("greater_than_equal" in exc_info.value.message
                 or "less_than_equal" in exc_info.value.message)
+
+
+@pytest.mark.asyncio
+async def test_http_chat_wo_model_name(server: RemoteOpenAIServer):
+    url = f"http://localhost:{server.port}/v1/chat/completions"
+    headers = {
+        "Content-Type": "application/json",
+    }
+    data = {
+        # The "model" field is intentionally omitted.
+        "messages": [{
+            "role": "system",
+            "content": "You are a helpful assistant."
+        }, {
+            "role": "user",
+            "content": "what is 1+1?"
+        }],
+        "max_tokens": 5,
+    }
+
+    response = requests.post(url, headers=headers, json=data)
+    response_data = response.json()
+    print(response_data)
+
+    choice = response_data.get("choices")[0]
+    message = choice.get("message")
+    assert message is not None
+    content = message.get("content")
+    assert content is not None
+    assert len(content) > 0


@@ -213,7 +213,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/chat/create
     messages: List[ChatCompletionMessageParam]
-    model: str
+    model: Optional[str] = None
     frequency_penalty: Optional[float] = 0.0
     logit_bias: Optional[Dict[str, float]] = None
     logprobs: Optional[bool] = False
@@ -642,7 +642,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
 class CompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/completions/create
-    model: str
+    model: Optional[str] = None
     prompt: Union[List[int], List[List[int]], str, List[str]]
     best_of: Optional[int] = None
     echo: Optional[bool] = False
@@ -907,7 +907,7 @@ class CompletionRequest(OpenAIBaseModel):
 class EmbeddingCompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/embeddings
-    model: str
+    model: Optional[str] = None
     input: Union[List[int], List[List[int]], str, List[str]]
     encoding_format: Literal["float", "base64"] = "float"
     dimensions: Optional[int] = None
@@ -939,7 +939,7 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
 class EmbeddingChatRequest(OpenAIBaseModel):
-    model: str
+    model: Optional[str] = None
     messages: List[ChatCompletionMessageParam]
     encoding_format: Literal["float", "base64"] = "float"
@@ -1007,7 +1007,7 @@ PoolingRequest = Union[PoolingCompletionRequest, PoolingChatRequest]
 class ScoreRequest(OpenAIBaseModel):
-    model: str
+    model: Optional[str] = None
     text_1: Union[List[str], str]
     text_2: Union[List[str], str]
     truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
@@ -1031,7 +1031,7 @@ class ScoreRequest(OpenAIBaseModel):
 class RerankRequest(OpenAIBaseModel):
-    model: str
+    model: Optional[str] = None
     query: str
     documents: List[str]
     top_n: int = Field(default_factory=lambda: 0)
@@ -1345,7 +1345,7 @@ class BatchRequestOutput(OpenAIBaseModel):
 class TokenizeCompletionRequest(OpenAIBaseModel):
-    model: str
+    model: Optional[str] = None
     prompt: str
     add_special_tokens: bool = Field(
@@ -1357,7 +1357,7 @@ class TokenizeCompletionRequest(OpenAIBaseModel):
 class TokenizeChatRequest(OpenAIBaseModel):
-    model: str
+    model: Optional[str] = None
     messages: List[ChatCompletionMessageParam]
     add_generation_prompt: bool = Field(
@@ -1423,7 +1423,7 @@ class TokenizeResponse(OpenAIBaseModel):
 class DetokenizeRequest(OpenAIBaseModel):
-    model: str
+    model: Optional[str] = None
     tokens: List[int]
@@ -1456,7 +1456,7 @@ class TranscriptionRequest(OpenAIBaseModel):
     formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
     """
-    model: str
+    model: Optional[str] = None
     """ID of the model to use.
     """


@@ -141,7 +141,7 @@ class OpenAIServingChat(OpenAIServing):
                 prompt_adapter_request,
             ) = self._maybe_get_adapters(request)

-            model_name = self.models.model_name(lora_request)
+            model_name = self._get_model_name(request.model, lora_request)

             tokenizer = await self.engine_client.get_tokenizer(lora_request)


@@ -166,7 +166,7 @@ class OpenAIServingCompletion(OpenAIServing):

         result_generator = merge_async_iterators(*generators)

-        model_name = self.models.model_name(lora_request)
+        model_name = self._get_model_name(request.model, lora_request)
         num_prompts = len(engine_prompts)

         # Similar to the OpenAI API, when n != best_of, we do not stream the


@@ -83,7 +83,7 @@ class OpenAIServingEmbedding(OpenAIServing):
             return self.create_error_response(
                 "dimensions is currently not supported")

-        model_name = request.model
+        model_name = self._get_model_name(request.model)
         request_id = f"embd-{self._base_request_id(raw_request)}"
         created_time = int(time.time())


@@ -523,5 +523,16 @@ class OpenAIServing:
             return logprob.decoded_token
         return tokenizer.decode(token_id)

-    def _is_model_supported(self, model_name):
+    def _is_model_supported(self, model_name) -> bool:
+        if not model_name:
+            return True
         return self.models.is_base_model(model_name)
+
+    def _get_model_name(self,
+                        model_name: Optional[str] = None,
+                        lora_request: Optional[LoRARequest] = None) -> str:
+        if lora_request:
+            return lora_request.lora_name
+        if model_name is None:
+            return self.models.base_model_paths[0].name
+        return model_name
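The new _get_model_name helper encodes the fallback order used throughout the serving classes: a resolved LoRA adapter wins, then an explicitly supplied model name, and only then the first configured base model. A standalone sketch of that precedence (LoRAStub, resolve_model_name, and default_name are illustrative names, not vLLM APIs):

from dataclasses import dataclass
from typing import Optional


@dataclass
class LoRAStub:
    lora_name: str


def resolve_model_name(model_name: Optional[str],
                       lora_request: Optional[LoRAStub],
                       default_name: str) -> str:
    # Same precedence as _get_model_name above: LoRA adapter name,
    # then the request's explicit model, then the served default.
    if lora_request:
        return lora_request.lora_name
    if model_name is None:
        return default_name
    return model_name


print(resolve_model_name(None, None, "base-model"))                 # base-model
print(resolve_model_name("my-model", None, "base-model"))           # my-model
print(resolve_model_name(None, LoRAStub("my-lora"), "base-model"))  # my-lora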


@@ -95,7 +95,7 @@ class OpenAIServingModels:
         if isinstance(load_result, ErrorResponse):
             raise ValueError(load_result.message)

-    def is_base_model(self, model_name):
+    def is_base_model(self, model_name) -> bool:
         return any(model.name == model_name for model in self.base_model_paths)

     def model_name(self, lora_request: Optional[LoRARequest] = None) -> str:


@@ -79,7 +79,7 @@ class OpenAIServingPooling(OpenAIServing):
             return self.create_error_response(
                 "dimensions is currently not supported")

-        model_name = request.model
+        model_name = self._get_model_name(request.model)
         request_id = f"pool-{self._base_request_id(raw_request)}"
         created_time = int(time.time())


@@ -318,7 +318,7 @@ class ServingScores(OpenAIServing):
                 final_res_batch,
                 request_id,
                 created_time,
-                request.model,
+                self._get_model_name(request.model),
             )
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
@@ -358,7 +358,7 @@ class ServingScores(OpenAIServing):
                 request.truncate_prompt_tokens,
             )

             return self.request_output_to_rerank_response(
-                final_res_batch, request_id, request.model, documents, top_n)
+                final_res_batch, request_id, self._get_model_name(request.model), documents, top_n)
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
         except ValueError as e: