diff --git a/tests/entrypoints/openai/test_chat.py b/tests/entrypoints/openai/test_chat.py
index c9947c54a918..4608850c7dae 100644
--- a/tests/entrypoints/openai/test_chat.py
+++ b/tests/entrypoints/openai/test_chat.py
@@ -12,7 +12,7 @@
 import pytest_asyncio
 import regex as re
 import requests
 import torch
-from openai import BadRequestError, OpenAI
+from openai import BadRequestError
 
 from ...utils import RemoteOpenAIServer
@@ -968,59 +968,6 @@ async def test_long_seed(client: openai.AsyncOpenAI):
             or "less_than_equal" in exc_info.value.message)
 
 
-@pytest.mark.asyncio
-async def test_http_chat_no_model_name_with_curl(server: RemoteOpenAIServer):
-    url = f"http://localhost:{server.port}/v1/chat/completions"
-    headers = {
-        "Content-Type": "application/json",
-    }
-    data = {
-        # model_name is avoided here.
-        "messages": [{
-            "role": "system",
-            "content": "You are a helpful assistant."
-        }, {
-            "role": "user",
-            "content": "what is 1+1?"
-        }],
-        "max_tokens":
-        5
-    }
-
-    response = requests.post(url, headers=headers, json=data)
-    response_data = response.json()
-    print(response_data)
-    assert response_data.get("model") == MODEL_NAME
-    choice = response_data.get("choices")[0]
-    message = choice.get("message")
-    assert message is not None
-    content = message.get("content")
-    assert content is not None
-    assert len(content) > 0
-
-
-@pytest.mark.asyncio
-async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer):
-    openai_api_key = "EMPTY"
-    openai_api_base = f"http://localhost:{server.port}/v1"
-
-    client = OpenAI(
-        api_key=openai_api_key,
-        base_url=openai_api_base,
-    )
-    messages = [
-        {
-            "role": "user",
-            "content": "Hello, vLLM!"
-        },
-    ]
-    response = client.chat.completions.create(
-        model="",  # empty string
-        messages=messages,
-    )
-    assert response.model == MODEL_NAME
-
-
 @pytest.mark.asyncio
 async def test_invocations(server: RemoteOpenAIServer,
                            client: openai.AsyncOpenAI):
diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py
index 39a18b9daeca..d219a1f311f1 100644
--- a/tests/entrypoints/openai/test_serving_chat.py
+++ b/tests/entrypoints/openai/test_serving_chat.py
@@ -213,8 +213,12 @@ async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI,
 
 
 MODEL_NAME = "openai-community/gpt2"
+MODEL_NAME_SHORT = "gpt2"
 CHAT_TEMPLATE = "Dummy chat template for testing {}"
-BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
+BASE_MODEL_PATHS = [
+    BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME),
+    BaseModelPath(name=MODEL_NAME_SHORT, model_path=MODEL_NAME_SHORT)
+]
 
 
 @dataclass
@@ -270,6 +274,42 @@ def test_async_serving_chat_init():
     assert serving_completion.chat_template == CHAT_TEMPLATE
 
 
+@pytest.mark.asyncio
+async def test_serving_chat_returns_correct_model_name():
+    mock_engine = MagicMock(spec=MQLLMEngineClient)
+    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
+    mock_engine.errored = False
+
+    models = OpenAIServingModels(engine_client=mock_engine,
+                                 base_model_paths=BASE_MODEL_PATHS,
+                                 model_config=MockModelConfig())
+    serving_chat = OpenAIServingChat(mock_engine,
+                                     MockModelConfig(),
+                                     models,
+                                     response_role="assistant",
+                                     chat_template=CHAT_TEMPLATE,
+                                     chat_template_content_format="auto",
+                                     request_logger=None)
+    messages = [{"role": "user", "content": "what is 1+1?"}]
+
+    async def return_model_name(*args):
+        return args[3]
+
+    serving_chat.chat_completion_full_generator = return_model_name
+
+    # Test that full name is returned when short name is requested
+    req = ChatCompletionRequest(model=MODEL_NAME_SHORT, messages=messages)
+    assert await serving_chat.create_chat_completion(req) == MODEL_NAME
+
+    # Test that full name is returned when empty string is specified
+    req = ChatCompletionRequest(model="", messages=messages)
+    assert await serving_chat.create_chat_completion(req) == MODEL_NAME
+
+    # Test that full name is returned when no model is specified
+    req = ChatCompletionRequest(messages=messages)
+    assert await serving_chat.create_chat_completion(req) == MODEL_NAME
+
+
 @pytest.mark.asyncio
 async def test_serving_chat_should_set_correct_max_tokens():
     mock_engine = MagicMock(spec=MQLLMEngineClient)
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 5c7adc53f49b..579f6f537ee2 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -186,7 +186,7 @@ class OpenAIServingChat(OpenAIServing):
             lora_request = self._maybe_get_adapters(
                 request, supports_default_mm_loras=True)
 
-            model_name = self._get_model_name(request.model, lora_request)
+            model_name = self.models.model_name(lora_request)
 
             tokenizer = await self.engine_client.get_tokenizer(lora_request)
 
diff --git a/vllm/entrypoints/openai/serving_classification.py b/vllm/entrypoints/openai/serving_classification.py
index ef730442691c..7e88424c169c 100644
--- a/vllm/entrypoints/openai/serving_classification.py
+++ b/vllm/entrypoints/openai/serving_classification.py
@@ -146,7 +146,7 @@ class ServingClassification(ClassificationMixin):
         request: ClassificationRequest,
         raw_request: Request,
     ) -> Union[ClassificationResponse, ErrorResponse]:
-        model_name = self._get_model_name(request.model)
+        model_name = self.models.model_name()
         request_id = (f"{self.request_id_prefix}-"
                       f"{self._base_request_id(raw_request)}")
 
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 10b1c54fe6a9..c2de449a9699 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -232,7 +232,7 @@ class OpenAIServingCompletion(OpenAIServing):
 
         result_generator = merge_async_iterators(*generators)
 
-        model_name = self._get_model_name(request.model, lora_request)
+        model_name = self.models.model_name(lora_request)
         num_prompts = len(engine_prompts)
 
         # Similar to the OpenAI API, when n != best_of, we do not stream the
diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py
index 476c21ad6a30..c0d1fe4b6e16 100644
--- a/vllm/entrypoints/openai/serving_embedding.py
+++ b/vllm/entrypoints/openai/serving_embedding.py
@@ -599,7 +599,7 @@ class OpenAIServingEmbedding(EmbeddingMixin):
         See https://platform.openai.com/docs/api-reference/embeddings/create
         for the API specification. This API mimics the OpenAI Embedding API.
         """
-        model_name = self._get_model_name(request.model)
+        model_name = self.models.model_name()
         request_id = (
             f"{self.request_id_prefix}-"
             f"{self._base_request_id(raw_request, request.request_id)}")
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 708b1a5facf4..d391cc50ad23 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -980,17 +980,6 @@ class OpenAIServing:
                 return True
         return self.models.is_base_model(model_name)
 
-    def _get_model_name(
-        self,
-        model_name: Optional[str] = None,
-        lora_request: Optional[LoRARequest] = None,
-    ) -> str:
-        if lora_request:
-            return lora_request.lora_name
-        if not model_name:
-            return self.models.base_model_paths[0].name
-        return model_name
-
 
 def clamp_prompt_logprobs(
     prompt_logprobs: Union[PromptLogprobs,
diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py
index d70e1a808aba..cac1d1ba5683 100644
--- a/vllm/entrypoints/openai/serving_pooling.py
+++ b/vllm/entrypoints/openai/serving_pooling.py
@@ -91,7 +91,7 @@ class OpenAIServingPooling(OpenAIServing):
         if error_check_ret is not None:
             return error_check_ret
 
-        model_name = self._get_model_name(request.model)
+        model_name = self.models.model_name()
         request_id = f"pool-{self._base_request_id(raw_request)}"
         created_time = int(time.time())
 
diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py
index c5177bdf5375..401ba6c53331 100644
--- a/vllm/entrypoints/openai/serving_responses.py
+++ b/vllm/entrypoints/openai/serving_responses.py
@@ -237,7 +237,7 @@ class OpenAIServingResponses(OpenAIServing):
         try:
             lora_request = self._maybe_get_adapters(request)
 
-            model_name = self._get_model_name(request.model, lora_request)
+            model_name = self.models.model_name(lora_request)
 
             tokenizer = await self.engine_client.get_tokenizer(lora_request)
             if self.use_harmony:
diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py
index 847c014a11dc..24767ed66fc6 100644
--- a/vllm/entrypoints/openai/serving_score.py
+++ b/vllm/entrypoints/openai/serving_score.py
@@ -353,7 +353,7 @@ class ServingScores(OpenAIServing):
                 final_res_batch,
                 request_id,
                 created_time,
-                self._get_model_name(request.model),
+                self.models.model_name(),
             )
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
@@ -399,7 +399,7 @@ class ServingScores(OpenAIServing):
             return self.request_output_to_rerank_response(
                 final_res_batch,
                 request_id,
-                self._get_model_name(request.model),
+                self.models.model_name(),
                 documents,
                 top_n,
             )