Fix model name included in responses (#24663)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor 2025-09-11 18:47:51 +01:00 committed by GitHub
parent 4aa23892d6
commit c1eda615ba
10 changed files with 50 additions and 74 deletions

View File

@@ -12,7 +12,7 @@ import pytest_asyncio
 import regex as re
 import requests
 import torch
-from openai import BadRequestError, OpenAI
+from openai import BadRequestError

 from ...utils import RemoteOpenAIServer
@@ -968,59 +968,6 @@ async def test_long_seed(client: openai.AsyncOpenAI):
             or "less_than_equal" in exc_info.value.message)


-@pytest.mark.asyncio
-async def test_http_chat_no_model_name_with_curl(server: RemoteOpenAIServer):
-    url = f"http://localhost:{server.port}/v1/chat/completions"
-    headers = {
-        "Content-Type": "application/json",
-    }
-    data = {
-        # model_name is avoided here.
-        "messages": [{
-            "role": "system",
-            "content": "You are a helpful assistant."
-        }, {
-            "role": "user",
-            "content": "what is 1+1?"
-        }],
-        "max_tokens": 5
-    }
-
-    response = requests.post(url, headers=headers, json=data)
-    response_data = response.json()
-    print(response_data)
-
-    assert response_data.get("model") == MODEL_NAME
-
-    choice = response_data.get("choices")[0]
-    message = choice.get("message")
-    assert message is not None
-    content = message.get("content")
-    assert content is not None
-    assert len(content) > 0
-
-
-@pytest.mark.asyncio
-async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer):
-    openai_api_key = "EMPTY"
-    openai_api_base = f"http://localhost:{server.port}/v1"
-    client = OpenAI(
-        api_key=openai_api_key,
-        base_url=openai_api_base,
-    )
-
-    messages = [
-        {
-            "role": "user",
-            "content": "Hello, vLLM!"
-        },
-    ]
-
-    response = client.chat.completions.create(
-        model="",  # empty string
-        messages=messages,
-    )
-    assert response.model == MODEL_NAME
-
-
 @pytest.mark.asyncio
 async def test_invocations(server: RemoteOpenAIServer,
                            client: openai.AsyncOpenAI):
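
The two HTTP-level tests removed here are superseded by the focused unit test added in the next file, but the end-to-end behavior they exercised can still be checked by hand. A minimal sketch using `requests`, where the server URL and served model name are assumptions about a local deployment:

```python
import requests

BASE_URL = "http://localhost:8000/v1"   # assumed local vLLM server
SERVED_MODEL = "openai-community/gpt2"  # assumed served model name

payload = {
    # "model" is deliberately omitted; the server should fall back to
    # its first configured base model and echo that full name back.
    "messages": [{"role": "user", "content": "what is 1+1?"}],
    "max_tokens": 5,
}
response = requests.post(f"{BASE_URL}/chat/completions", json=payload)
response.raise_for_status()
assert response.json()["model"] == SERVED_MODEL
```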

View File

@@ -213,8 +213,12 @@ async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI,
 MODEL_NAME = "openai-community/gpt2"
+MODEL_NAME_SHORT = "gpt2"
 CHAT_TEMPLATE = "Dummy chat template for testing {}"
-BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
+BASE_MODEL_PATHS = [
+    BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME),
+    BaseModelPath(name=MODEL_NAME_SHORT, model_path=MODEL_NAME_SHORT)
+]


 @dataclass
@@ -270,6 +274,42 @@ def test_async_serving_chat_init():
     assert serving_completion.chat_template == CHAT_TEMPLATE


+@pytest.mark.asyncio
+async def test_serving_chat_returns_correct_model_name():
+    mock_engine = MagicMock(spec=MQLLMEngineClient)
+    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
+    mock_engine.errored = False
+
+    models = OpenAIServingModels(engine_client=mock_engine,
+                                 base_model_paths=BASE_MODEL_PATHS,
+                                 model_config=MockModelConfig())
+    serving_chat = OpenAIServingChat(mock_engine,
+                                     MockModelConfig(),
+                                     models,
+                                     response_role="assistant",
+                                     chat_template=CHAT_TEMPLATE,
+                                     chat_template_content_format="auto",
+                                     request_logger=None)
+    messages = [{"role": "user", "content": "what is 1+1?"}]
+
+    async def return_model_name(*args):
+        return args[3]
+
+    serving_chat.chat_completion_full_generator = return_model_name
+
+    # Test that full name is returned when short name is requested
+    req = ChatCompletionRequest(model=MODEL_NAME_SHORT, messages=messages)
+    assert await serving_chat.create_chat_completion(req) == MODEL_NAME
+
+    # Test that full name is returned when empty string is specified
+    req = ChatCompletionRequest(model="", messages=messages)
+    assert await serving_chat.create_chat_completion(req) == MODEL_NAME
+
+    # Test that full name is returned when no model is specified
+    req = ChatCompletionRequest(messages=messages)
+    assert await serving_chat.create_chat_completion(req) == MODEL_NAME
+
+
 @pytest.mark.asyncio
 async def test_serving_chat_should_set_correct_max_tokens():
     mock_engine = MagicMock(spec=MQLLMEngineClient)
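
The `return_model_name` stub in the new test works because assigning a plain function to the instance bypasses Python's method binding: the stub receives the call's positional arguments directly, so `args[3]` is the resolved `model_name` that `create_chat_completion` passes to `chat_completion_full_generator`. A self-contained illustration of that pattern, using toy names rather than vLLM's real API:

```python
import asyncio

class Caller:
    """Toy stand-in for OpenAIServingChat (illustration only)."""

    async def full_generator(self, request, result_generator, request_id,
                             model_name):
        return "real response"

    async def create(self, request):
        model_name = "openai-community/gpt2"  # the value under test
        return await self.full_generator(request, None, "id-0", model_name)

async def main():
    caller = Caller()

    async def return_model_name(*args):
        return args[3]  # positional slot of model_name in full_generator

    # The instance attribute shadows the bound method, so the shim sees
    # the call's positional arguments without an implicit self.
    caller.full_generator = return_model_name
    assert await caller.create({}) == "openai-community/gpt2"

asyncio.run(main())
```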

View File

@@ -186,7 +186,7 @@ class OpenAIServingChat(OpenAIServing):
             lora_request = self._maybe_get_adapters(
                 request, supports_default_mm_loras=True)

-            model_name = self._get_model_name(request.model, lora_request)
+            model_name = self.models.model_name(lora_request)

             tokenizer = await self.engine_client.get_tokenizer(lora_request)
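
This and the remaining files all delegate to `OpenAIServingModels.model_name`, whose body is not part of this diff. Judging from the removed `_get_model_name` helper and the new unit test, a plausible sketch of its behavior (the class shape and signature here are assumptions):

```python
from collections import namedtuple

BasePath = namedtuple("BasePath", ["name", "model_path"])  # stand-in

class OpenAIServingModelsSketch:
    """Sketch only; the real OpenAIServingModels lives in vLLM."""

    def __init__(self, base_model_paths):
        self.base_model_paths = base_model_paths

    def model_name(self, lora_request=None) -> str:
        # A LoRA request is reported under the adapter's own name;
        # otherwise the first configured base model name is canonical.
        if lora_request is not None:
            return lora_request.lora_name
        return self.base_model_paths[0].name

models = OpenAIServingModelsSketch(
    [BasePath("openai-community/gpt2", "openai-community/gpt2"),
     BasePath("gpt2", "gpt2")])
assert models.model_name() == "openai-community/gpt2"
```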

View File

@@ -146,7 +146,7 @@ class ServingClassification(ClassificationMixin):
         request: ClassificationRequest,
         raw_request: Request,
     ) -> Union[ClassificationResponse, ErrorResponse]:
-        model_name = self._get_model_name(request.model)
+        model_name = self.models.model_name()
         request_id = (f"{self.request_id_prefix}-"
                       f"{self._base_request_id(raw_request)}")

View File

@@ -232,7 +232,7 @@ class OpenAIServingCompletion(OpenAIServing):
         result_generator = merge_async_iterators(*generators)

-        model_name = self._get_model_name(request.model, lora_request)
+        model_name = self.models.model_name(lora_request)
         num_prompts = len(engine_prompts)

         # Similar to the OpenAI API, when n != best_of, we do not stream the

View File

@@ -599,7 +599,7 @@ class OpenAIServingEmbedding(EmbeddingMixin):
         See https://platform.openai.com/docs/api-reference/embeddings/create
         for the API specification. This API mimics the OpenAI Embedding API.
         """
-        model_name = self._get_model_name(request.model)
+        model_name = self.models.model_name()
         request_id = (
             f"{self.request_id_prefix}-"
             f"{self._base_request_id(raw_request, request.request_id)}")

View File

@@ -980,17 +980,6 @@ class OpenAIServing:
                 return True
         return self.models.is_base_model(model_name)

-    def _get_model_name(
-        self,
-        model_name: Optional[str] = None,
-        lora_request: Optional[LoRARequest] = None,
-    ) -> str:
-        if lora_request:
-            return lora_request.lora_name
-        if not model_name:
-            return self.models.base_model_paths[0].name
-        return model_name
-

 def clamp_prompt_logprobs(
     prompt_logprobs: Union[PromptLogprobs,
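
Note the deliberate behavioral change in this removal: the old helper echoed a non-empty `request.model` back to the client, so requesting an alias such as `gpt2` leaked that alias into the response, while the replacement always reports the canonical served name (or the LoRA adapter name). A self-contained sketch of the difference, with the LoRA branch elided:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class BaseModelPath:
    name: str
    model_path: str

BASE_MODEL_PATHS = [
    BaseModelPath("openai-community/gpt2", "openai-community/gpt2")
]

def old_get_model_name(model_name: Optional[str] = None) -> str:
    # Removed behavior: a non-empty requested name (even an alias)
    # was echoed straight back to the client.
    if not model_name:
        return BASE_MODEL_PATHS[0].name
    return model_name

def new_model_name() -> str:
    # Replacement behavior: the canonical first base model name,
    # regardless of what the request asked for.
    return BASE_MODEL_PATHS[0].name

assert old_get_model_name("gpt2") == "gpt2"           # alias leaked
assert new_model_name() == "openai-community/gpt2"    # canonical name
```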

View File

@@ -91,7 +91,7 @@ class OpenAIServingPooling(OpenAIServing):
         if error_check_ret is not None:
             return error_check_ret

-        model_name = self._get_model_name(request.model)
+        model_name = self.models.model_name()
         request_id = f"pool-{self._base_request_id(raw_request)}"
         created_time = int(time.time())

View File

@@ -237,7 +237,7 @@ class OpenAIServingResponses(OpenAIServing):
         try:
             lora_request = self._maybe_get_adapters(request)

-            model_name = self._get_model_name(request.model, lora_request)
+            model_name = self.models.model_name(lora_request)
             tokenizer = await self.engine_client.get_tokenizer(lora_request)

             if self.use_harmony:
View File

@@ -353,7 +353,7 @@ class ServingScores(OpenAIServing):
                 final_res_batch,
                 request_id,
                 created_time,
-                self._get_model_name(request.model),
+                self.models.model_name(),
             )
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
@@ -399,7 +399,7 @@ class ServingScores(OpenAIServing):
             return self.request_output_to_rerank_response(
                 final_res_batch,
                 request_id,
-                self._get_model_name(request.model),
+                self.models.model_name(),
                 documents,
                 top_n,
             )
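
With every endpoint above now deriving the reported name from `OpenAIServingModels`, the client-visible fix is uniform across chat, completions, classification, embeddings, pooling, responses, and score/rerank. An end-to-end check with the official `openai` client, mirroring the removed `test_http_chat_no_model_name_with_openai` (the URL and served model name are assumptions about your deployment):

```python
from openai import OpenAI

# Assumes a vLLM server on localhost:8000 serving "openai-community/gpt2".
client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")

response = client.chat.completions.create(
    model="",  # empty string: the server must still resolve a model
    messages=[{"role": "user", "content": "Hello, vLLM!"}],
)
# After this fix, the response carries the canonical served name.
assert response.model == "openai-community/gpt2"
```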