Fix model name included in responses (#24663)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor 2025-09-11 18:47:51 +01:00 committed by GitHub
parent 4aa23892d6
commit c1eda615ba
10 changed files with 50 additions and 74 deletions
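For context, the user-visible behaviour this commit fixes: when a request omits the "model" field, sends an empty string, or uses a short served-model alias, the response's "model" field should report the server's full served model name rather than echoing the request value. A minimal sketch of checking this against a running server, assuming it listens on localhost:8000 (the port and prompt are illustrative, not taken from this commit):

import requests

# Chat completion request that deliberately omits the "model" field.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    headers={"Content-Type": "application/json"},
    json={
        "messages": [{"role": "user", "content": "what is 1+1?"}],
        "max_tokens": 5,
    },
)
# After this fix, the server reports its full served model name here.
print(resp.json()["model"])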

View File

@@ -12,7 +12,7 @@ import pytest_asyncio
 import regex as re
 import requests
 import torch
-from openai import BadRequestError, OpenAI
+from openai import BadRequestError

 from ...utils import RemoteOpenAIServer
@@ -968,59 +968,6 @@ async def test_long_seed(client: openai.AsyncOpenAI):
                 or "less_than_equal" in exc_info.value.message)


-@pytest.mark.asyncio
-async def test_http_chat_no_model_name_with_curl(server: RemoteOpenAIServer):
-    url = f"http://localhost:{server.port}/v1/chat/completions"
-    headers = {
-        "Content-Type": "application/json",
-    }
-    data = {
-        # model_name is avoided here.
-        "messages": [{
-            "role": "system",
-            "content": "You are a helpful assistant."
-        }, {
-            "role": "user",
-            "content": "what is 1+1?"
-        }],
-        "max_tokens":
-        5
-    }
-
-    response = requests.post(url, headers=headers, json=data)
-    response_data = response.json()
-    print(response_data)
-
-    assert response_data.get("model") == MODEL_NAME
-
-    choice = response_data.get("choices")[0]
-    message = choice.get("message")
-    assert message is not None
-    content = message.get("content")
-    assert content is not None
-    assert len(content) > 0
-
-
-@pytest.mark.asyncio
-async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer):
-    openai_api_key = "EMPTY"
-    openai_api_base = f"http://localhost:{server.port}/v1"
-    client = OpenAI(
-        api_key=openai_api_key,
-        base_url=openai_api_base,
-    )
-    messages = [
-        {
-            "role": "user",
-            "content": "Hello, vLLM!"
-        },
-    ]
-
-    response = client.chat.completions.create(
-        model="",  # empty string
-        messages=messages,
-    )
-    assert response.model == MODEL_NAME
-
-
 @pytest.mark.asyncio
 async def test_invocations(server: RemoteOpenAIServer,
                            client: openai.AsyncOpenAI):

View File

@@ -213,8 +213,12 @@ async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI,

 MODEL_NAME = "openai-community/gpt2"
+MODEL_NAME_SHORT = "gpt2"
 CHAT_TEMPLATE = "Dummy chat template for testing {}"
-BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
+BASE_MODEL_PATHS = [
+    BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME),
+    BaseModelPath(name=MODEL_NAME_SHORT, model_path=MODEL_NAME_SHORT)
+]


 @dataclass
@@ -270,6 +274,42 @@ def test_async_serving_chat_init():
     assert serving_completion.chat_template == CHAT_TEMPLATE


+@pytest.mark.asyncio
+async def test_serving_chat_returns_correct_model_name():
+    mock_engine = MagicMock(spec=MQLLMEngineClient)
+    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
+    mock_engine.errored = False
+
+    models = OpenAIServingModels(engine_client=mock_engine,
+                                 base_model_paths=BASE_MODEL_PATHS,
+                                 model_config=MockModelConfig())
+    serving_chat = OpenAIServingChat(mock_engine,
+                                     MockModelConfig(),
+                                     models,
+                                     response_role="assistant",
+                                     chat_template=CHAT_TEMPLATE,
+                                     chat_template_content_format="auto",
+                                     request_logger=None)
+    messages = [{"role": "user", "content": "what is 1+1?"}]
+
+    async def return_model_name(*args):
+        return args[3]
+
+    serving_chat.chat_completion_full_generator = return_model_name
+
+    # Test that full name is returned when short name is requested
+    req = ChatCompletionRequest(model=MODEL_NAME_SHORT, messages=messages)
+    assert await serving_chat.create_chat_completion(req) == MODEL_NAME
+
+    # Test that full name is returned when empty string is specified
+    req = ChatCompletionRequest(model="", messages=messages)
+    assert await serving_chat.create_chat_completion(req) == MODEL_NAME
+
+    # Test that full name is returned when no model is specified
+    req = ChatCompletionRequest(messages=messages)
+    assert await serving_chat.create_chat_completion(req) == MODEL_NAME
+
+
 @pytest.mark.asyncio
 async def test_serving_chat_should_set_correct_max_tokens():
     mock_engine = MagicMock(spec=MQLLMEngineClient)

View File

@@ -186,7 +186,7 @@ class OpenAIServingChat(OpenAIServing):
             lora_request = self._maybe_get_adapters(
                 request, supports_default_mm_loras=True)

-            model_name = self._get_model_name(request.model, lora_request)
+            model_name = self.models.model_name(lora_request)

             tokenizer = await self.engine_client.get_tokenizer(lora_request)

View File

@@ -146,7 +146,7 @@ class ServingClassification(ClassificationMixin):
         request: ClassificationRequest,
         raw_request: Request,
     ) -> Union[ClassificationResponse, ErrorResponse]:
-        model_name = self._get_model_name(request.model)
+        model_name = self.models.model_name()
         request_id = (f"{self.request_id_prefix}-"
                       f"{self._base_request_id(raw_request)}")

View File

@@ -232,7 +232,7 @@ class OpenAIServingCompletion(OpenAIServing):
         result_generator = merge_async_iterators(*generators)

-        model_name = self._get_model_name(request.model, lora_request)
+        model_name = self.models.model_name(lora_request)
         num_prompts = len(engine_prompts)

         # Similar to the OpenAI API, when n != best_of, we do not stream the

View File

@@ -599,7 +599,7 @@ class OpenAIServingEmbedding(EmbeddingMixin):
         See https://platform.openai.com/docs/api-reference/embeddings/create
         for the API specification. This API mimics the OpenAI Embedding API.
         """
-        model_name = self._get_model_name(request.model)
+        model_name = self.models.model_name()
         request_id = (
             f"{self.request_id_prefix}-"
             f"{self._base_request_id(raw_request, request.request_id)}")

View File

@@ -980,17 +980,6 @@ class OpenAIServing:
             return True
         return self.models.is_base_model(model_name)

-    def _get_model_name(
-        self,
-        model_name: Optional[str] = None,
-        lora_request: Optional[LoRARequest] = None,
-    ) -> str:
-        if lora_request:
-            return lora_request.lora_name
-        if not model_name:
-            return self.models.base_model_paths[0].name
-        return model_name
-
     def clamp_prompt_logprobs(
         prompt_logprobs: Union[PromptLogprobs,
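The call sites above now delegate name resolution to self.models.model_name(...). That helper is not shown in this diff; a rough sketch of the behaviour it presumably provides, mirroring the removed _get_model_name but without trusting the request's "model" string (the function and stub names below are illustrative assumptions, not the actual vLLM API):

from dataclasses import dataclass
from typing import Optional


@dataclass
class LoRARequestStub:
    # Stand-in for vllm.lora.request.LoRARequest; only the field used here.
    lora_name: str


def resolve_served_model_name(served_model_names: list[str],
                              lora_request: Optional[LoRARequestStub] = None) -> str:
    # Prefer the LoRA adapter's name when one is attached to the request;
    # otherwise always fall back to the first (full) served model name,
    # regardless of what the client put in the "model" field.
    if lora_request:
        return lora_request.lora_name
    return served_model_names[0]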

View File

@@ -91,7 +91,7 @@ class OpenAIServingPooling(OpenAIServing):
         if error_check_ret is not None:
             return error_check_ret

-        model_name = self._get_model_name(request.model)
+        model_name = self.models.model_name()
         request_id = f"pool-{self._base_request_id(raw_request)}"
         created_time = int(time.time())

View File

@@ -237,7 +237,7 @@ class OpenAIServingResponses(OpenAIServing):
         try:
             lora_request = self._maybe_get_adapters(request)
-            model_name = self._get_model_name(request.model, lora_request)
+            model_name = self.models.model_name(lora_request)
             tokenizer = await self.engine_client.get_tokenizer(lora_request)

             if self.use_harmony:

View File

@@ -353,7 +353,7 @@ class ServingScores(OpenAIServing):
                 final_res_batch,
                 request_id,
                 created_time,
-                self._get_model_name(request.model),
+                self.models.model_name(),
             )
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
@@ -399,7 +399,7 @@ class ServingScores(OpenAIServing):
            final_res_batch,
            request_id,
-           self._get_model_name(request.model),
+           self.models.model_name(),
            documents,
            top_n,
        )