diff --git a/tests/entrypoints/openai/test_chat.py b/tests/entrypoints/openai/test_chat.py
index 4b5ad55c5eda0..d7ed4afa28611 100644
--- a/tests/entrypoints/openai/test_chat.py
+++ b/tests/entrypoints/openai/test_chat.py
@@ -9,6 +9,7 @@
 import jsonschema
 import openai  # use the official client for correctness check
 import pytest
 import pytest_asyncio
+import requests
 import torch
 from openai import BadRequestError
@@ -996,3 +997,33 @@ async def test_long_seed(client: openai.AsyncOpenAI):
 
         assert ("greater_than_equal" in exc_info.value.message
                 or "less_than_equal" in exc_info.value.message)
+
+
+@pytest.mark.asyncio
+async def test_http_chat_wo_model_name(server: RemoteOpenAIServer):
+    url = f"http://localhost:{server.port}/v1/chat/completions"
+    headers = {
+        "Content-Type": "application/json",
+    }
+    data = {
+        # The "model" field is intentionally omitted here.
+        "messages": [{
+            "role": "system",
+            "content": "You are a helpful assistant."
+        }, {
+            "role": "user",
+            "content": "what is 1+1?"
+        }],
+        "max_tokens": 5
+    }
+
+    response = requests.post(url, headers=headers, json=data)
+    response_data = response.json()
+    print(response_data)
+
+    choice = response_data.get("choices")[0]
+    message = choice.get("message")
+    assert message is not None
+    content = message.get("content")
+    assert content is not None
+    assert len(content) > 0
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 29f64d28bdf1c..45b98a032bda9 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -213,7 +213,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/chat/create
     messages: List[ChatCompletionMessageParam]
-    model: str
+    model: Optional[str] = None
     frequency_penalty: Optional[float] = 0.0
     logit_bias: Optional[Dict[str, float]] = None
     logprobs: Optional[bool] = False
@@ -642,7 +642,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
 class CompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/completions/create
-    model: str
+    model: Optional[str] = None
     prompt: Union[List[int], List[List[int]], str, List[str]]
     best_of: Optional[int] = None
     echo: Optional[bool] = False
@@ -907,7 +907,7 @@ class CompletionRequest(OpenAIBaseModel):
 class EmbeddingCompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/embeddings
-    model: str
+    model: Optional[str] = None
     input: Union[List[int], List[List[int]], str, List[str]]
     encoding_format: Literal["float", "base64"] = "float"
     dimensions: Optional[int] = None
@@ -939,7 +939,7 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
 
 
 class EmbeddingChatRequest(OpenAIBaseModel):
-    model: str
+    model: Optional[str] = None
     messages: List[ChatCompletionMessageParam]
 
     encoding_format: Literal["float", "base64"] = "float"
@@ -1007,7 +1007,7 @@ PoolingRequest = Union[PoolingCompletionRequest, PoolingChatRequest]
 
 
 class ScoreRequest(OpenAIBaseModel):
-    model: str
+    model: Optional[str] = None
     text_1: Union[List[str], str]
     text_2: Union[List[str], str]
     truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
@@ -1031,7 +1031,7 @@ class ScoreRequest(OpenAIBaseModel):
 
 
 class RerankRequest(OpenAIBaseModel):
-    model: str
+    model: Optional[str] = None
     query: str
     documents: List[str]
     top_n: int = Field(default_factory=lambda: 0)
@@ -1345,7 +1345,7 @@ class BatchRequestOutput(OpenAIBaseModel):
 
 
 class TokenizeCompletionRequest(OpenAIBaseModel):
-    model: str
+    model: Optional[str] = None
     prompt: str
 
     add_special_tokens: bool = Field(
@@ -1357,7 +1357,7 @@ class TokenizeCompletionRequest(OpenAIBaseModel):
 
 
 class TokenizeChatRequest(OpenAIBaseModel):
-    model: str
+    model: Optional[str] = None
     messages: List[ChatCompletionMessageParam]
 
     add_generation_prompt: bool = Field(
@@ -1423,7 +1423,7 @@ class TokenizeResponse(OpenAIBaseModel):
 
 
 class DetokenizeRequest(OpenAIBaseModel):
-    model: str
+    model: Optional[str] = None
     tokens: List[int]
 
 
@@ -1456,7 +1456,7 @@ class TranscriptionRequest(OpenAIBaseModel):
     formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
     """
 
-    model: str
+    model: Optional[str] = None
     """ID of the model to use.
     """
 
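The protocol.py change above is the crux of the patch: every request schema now declares model as Optional[str] = None instead of a required string. A minimal sketch of the effect, assuming the patched protocol module is importable (the message content is illustrative):

    # Sketch only -- not part of the patch.
    from vllm.entrypoints.openai.protocol import ChatCompletionRequest

    # "model" is omitted; with the old required field this raised a pydantic
    # ValidationError, whereas it now parses and leaves model as None.
    req = ChatCompletionRequest(
        messages=[{"role": "user", "content": "what is 1+1?"}])
    assert req.model is None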
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 934bd2a95063c..02dd2c4881c62 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -141,7 +141,7 @@ class OpenAIServingChat(OpenAIServing):
                 prompt_adapter_request,
             ) = self._maybe_get_adapters(request)
 
-            model_name = self.models.model_name(lora_request)
+            model_name = self._get_model_name(request.model, lora_request)
 
             tokenizer = await self.engine_client.get_tokenizer(lora_request)
 
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index e7ad263e7fbe5..840f0f9b8448b 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -166,7 +166,7 @@ class OpenAIServingCompletion(OpenAIServing):
 
         result_generator = merge_async_iterators(*generators)
 
-        model_name = self.models.model_name(lora_request)
+        model_name = self._get_model_name(request.model, lora_request)
         num_prompts = len(engine_prompts)
 
         # Similar to the OpenAI API, when n != best_of, we do not stream the
diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py
index 45f8ad90ddcb3..607dbd96b1945 100644
--- a/vllm/entrypoints/openai/serving_embedding.py
+++ b/vllm/entrypoints/openai/serving_embedding.py
@@ -83,7 +83,7 @@ class OpenAIServingEmbedding(OpenAIServing):
             return self.create_error_response(
                 "dimensions is currently not supported")
 
-        model_name = request.model
+        model_name = self._get_model_name(request.model)
         request_id = f"embd-{self._base_request_id(raw_request)}"
         created_time = int(time.time())
 
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 5619e509c5544..05b5f95a5e59c 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -523,5 +523,16 @@ class OpenAIServing:
             return logprob.decoded_token
         return tokenizer.decode(token_id)
 
-    def _is_model_supported(self, model_name):
+    def _is_model_supported(self, model_name: Optional[str]) -> bool:
+        if not model_name:
+            return True
         return self.models.is_base_model(model_name)
+
+    def _get_model_name(self,
+                        model_name: Optional[str] = None,
+                        lora_request: Optional[LoRARequest] = None) -> str:
+        if lora_request:
+            return lora_request.lora_name
+        if model_name is None:
+            return self.models.base_model_paths[0].name
+        return model_name
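The new OpenAIServing._get_model_name helper above encodes the whole fallback policy: a matched LoRA adapter wins, then an explicitly supplied model name, then the first configured base model. A self-contained sketch of that resolution order, with illustrative names (resolve_model_name and the model strings are not part of the patch):

    from typing import List, Optional

    def resolve_model_name(requested: Optional[str],
                           lora_name: Optional[str],
                           base_models: List[str]) -> str:
        # 1. A resolved LoRA request takes precedence.
        if lora_name:
            return lora_name
        # 2. A client-supplied name is echoed back unchanged.
        if requested is not None:
            return requested
        # 3. Otherwise fall back to the first served base model.
        return base_models[0]

    assert resolve_model_name(None, None, ["Qwen/Qwen2.5-1.5B"]) \
        == "Qwen/Qwen2.5-1.5B"
    assert resolve_model_name("my-model", "my-lora", []) == "my-lora"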
diff --git a/vllm/entrypoints/openai/serving_models.py b/vllm/entrypoints/openai/serving_models.py
index f917a48519016..6ade4ece6d034 100644
--- a/vllm/entrypoints/openai/serving_models.py
+++ b/vllm/entrypoints/openai/serving_models.py
@@ -95,7 +95,7 @@ class OpenAIServingModels:
             if isinstance(load_result, ErrorResponse):
                 raise ValueError(load_result.message)
 
-    def is_base_model(self, model_name):
+    def is_base_model(self, model_name: str) -> bool:
         return any(model.name == model_name for model in self.base_model_paths)
 
     def model_name(self, lora_request: Optional[LoRARequest] = None) -> str:
diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py
index 01a3d211f6ba6..bbf5aed1a33c8 100644
--- a/vllm/entrypoints/openai/serving_pooling.py
+++ b/vllm/entrypoints/openai/serving_pooling.py
@@ -79,7 +79,7 @@ class OpenAIServingPooling(OpenAIServing):
             return self.create_error_response(
                 "dimensions is currently not supported")
 
-        model_name = request.model
+        model_name = self._get_model_name(request.model)
         request_id = f"pool-{self._base_request_id(raw_request)}"
         created_time = int(time.time())
 
diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py
index 0e9b355ad4f99..01e2d30436101 100644
--- a/vllm/entrypoints/openai/serving_score.py
+++ b/vllm/entrypoints/openai/serving_score.py
@@ -318,7 +318,7 @@ class ServingScores(OpenAIServing):
                 final_res_batch,
                 request_id,
                 created_time,
-                request.model,
+                self._get_model_name(request.model),
             )
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
@@ -358,7 +358,8 @@ class ServingScores(OpenAIServing):
                 request.truncate_prompt_tokens,
             )
             return self.request_output_to_rerank_response(
-                final_res_batch, request_id, request.model, documents, top_n)
+                final_res_batch, request_id,
+                self._get_model_name(request.model), documents, top_n)
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
         except ValueError as e:
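With the serving-layer changes in place, the "model" field of every response is filled via _get_model_name rather than copied verbatim from the request. A usage sketch, assuming a vLLM server is already running on localhost:8000 with a single base model (URL, port, and prompt are assumptions):

    # Sketch only -- mirrors the new test above, but checks the reported name.
    import requests

    resp = requests.post(
        "http://localhost:8000/v1/chat/completions",
        json={
            "messages": [{"role": "user", "content": "what is 1+1?"}],
            "max_tokens": 5,
        },
    )
    data = resp.json()
    # With no "model" in the request, the server reports the name of its
    # first configured base model (e.g. whatever --served-model-name set).
    print(data["model"])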