Fix model name included in responses (#24663)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 4aa23892d6
commit c1eda615ba
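In short, the OpenAI-compatible endpoints are changed to always report the canonical served model name instead of echoing whatever string the client sent. Below is a minimal sketch of the behaviour from a client's point of view, assuming a local vLLM server that serves openai-community/gpt2 (the model used by the unit tests in this diff); the host, port, and API key are illustrative.

from openai import OpenAI

# Illustrative endpoint; point this at wherever the vLLM server is running.
client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")

# Deliberately leave the model name empty, as the removed HTTP-level test did.
response = client.chat.completions.create(
    model="",
    messages=[{"role": "user", "content": "what is 1+1?"}],
    max_tokens=5,
)

# The response should carry the full served model name, not the empty string.
assert response.model == "openai-community/gpt2"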
@@ -12,7 +12,7 @@ import pytest_asyncio
 import regex as re
 import requests
 import torch
-from openai import BadRequestError, OpenAI
+from openai import BadRequestError

 from ...utils import RemoteOpenAIServer

@@ -968,59 +968,6 @@ async def test_long_seed(client: openai.AsyncOpenAI):
                 or "less_than_equal" in exc_info.value.message)


-@pytest.mark.asyncio
-async def test_http_chat_no_model_name_with_curl(server: RemoteOpenAIServer):
-    url = f"http://localhost:{server.port}/v1/chat/completions"
-    headers = {
-        "Content-Type": "application/json",
-    }
-    data = {
-        # model_name is avoided here.
-        "messages": [{
-            "role": "system",
-            "content": "You are a helpful assistant."
-        }, {
-            "role": "user",
-            "content": "what is 1+1?"
-        }],
-        "max_tokens":
-        5
-    }
-
-    response = requests.post(url, headers=headers, json=data)
-    response_data = response.json()
-    print(response_data)
-    assert response_data.get("model") == MODEL_NAME
-    choice = response_data.get("choices")[0]
-    message = choice.get("message")
-    assert message is not None
-    content = message.get("content")
-    assert content is not None
-    assert len(content) > 0
-
-
-@pytest.mark.asyncio
-async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer):
-    openai_api_key = "EMPTY"
-    openai_api_base = f"http://localhost:{server.port}/v1"
-
-    client = OpenAI(
-        api_key=openai_api_key,
-        base_url=openai_api_base,
-    )
-    messages = [
-        {
-            "role": "user",
-            "content": "Hello, vLLM!"
-        },
-    ]
-    response = client.chat.completions.create(
-        model="",  # empty string
-        messages=messages,
-    )
-    assert response.model == MODEL_NAME
-
-
 @pytest.mark.asyncio
 async def test_invocations(server: RemoteOpenAIServer,
                            client: openai.AsyncOpenAI):
@@ -213,8 +213,12 @@ async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI,


 MODEL_NAME = "openai-community/gpt2"
+MODEL_NAME_SHORT = "gpt2"
 CHAT_TEMPLATE = "Dummy chat template for testing {}"
-BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
+BASE_MODEL_PATHS = [
+    BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME),
+    BaseModelPath(name=MODEL_NAME_SHORT, model_path=MODEL_NAME_SHORT)
+]


 @dataclass
@@ -270,6 +274,42 @@ def test_async_serving_chat_init():
     assert serving_completion.chat_template == CHAT_TEMPLATE


+@pytest.mark.asyncio
+async def test_serving_chat_returns_correct_model_name():
+    mock_engine = MagicMock(spec=MQLLMEngineClient)
+    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
+    mock_engine.errored = False
+
+    models = OpenAIServingModels(engine_client=mock_engine,
+                                 base_model_paths=BASE_MODEL_PATHS,
+                                 model_config=MockModelConfig())
+    serving_chat = OpenAIServingChat(mock_engine,
+                                     MockModelConfig(),
+                                     models,
+                                     response_role="assistant",
+                                     chat_template=CHAT_TEMPLATE,
+                                     chat_template_content_format="auto",
+                                     request_logger=None)
+    messages = [{"role": "user", "content": "what is 1+1?"}]
+
+    async def return_model_name(*args):
+        return args[3]
+
+    serving_chat.chat_completion_full_generator = return_model_name
+
+    # Test that full name is returned when short name is requested
+    req = ChatCompletionRequest(model=MODEL_NAME_SHORT, messages=messages)
+    assert await serving_chat.create_chat_completion(req) == MODEL_NAME
+
+    # Test that full name is returned when empty string is specified
+    req = ChatCompletionRequest(model="", messages=messages)
+    assert await serving_chat.create_chat_completion(req) == MODEL_NAME
+
+    # Test that full name is returned when no model is specified
+    req = ChatCompletionRequest(messages=messages)
+    assert await serving_chat.create_chat_completion(req) == MODEL_NAME
+
+
 @pytest.mark.asyncio
 async def test_serving_chat_should_set_correct_max_tokens():
     mock_engine = MagicMock(spec=MQLLMEngineClient)
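In the new unit test above, chat_completion_full_generator is replaced by a coroutine that simply returns its fourth positional argument, which appears to be the model name resolved by create_chat_completion; the assertions therefore exercise only the name-resolution path and never run actual generation.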
@@ -186,7 +186,7 @@ class OpenAIServingChat(OpenAIServing):
             lora_request = self._maybe_get_adapters(
                 request, supports_default_mm_loras=True)

-            model_name = self._get_model_name(request.model, lora_request)
+            model_name = self.models.model_name(lora_request)

             tokenizer = await self.engine_client.get_tokenizer(lora_request)

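For reference, the name lookup that the handlers now delegate to lives on OpenAIServingModels. The snippet below is only an inferred sketch of that helper, reconstructed from the _get_model_name method removed later in this diff; the actual implementation may differ.

from typing import Optional

from vllm.lora.request import LoRARequest


class OpenAIServingModels:
    # ... construction and other methods elided; sketch only ...

    def model_name(self, lora_request: Optional[LoRARequest] = None) -> str:
        # A LoRA request is reported under its adapter name.
        if lora_request is not None:
            return lora_request.lora_name
        # Otherwise the first served model name is reported; unlike the
        # removed _get_model_name, the model string supplied in the request
        # is never echoed back.
        return self.base_model_paths[0].name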
@@ -146,7 +146,7 @@ class ServingClassification(ClassificationMixin):
         request: ClassificationRequest,
         raw_request: Request,
     ) -> Union[ClassificationResponse, ErrorResponse]:
-        model_name = self._get_model_name(request.model)
+        model_name = self.models.model_name()
         request_id = (f"{self.request_id_prefix}-"
                       f"{self._base_request_id(raw_request)}")

@@ -232,7 +232,7 @@ class OpenAIServingCompletion(OpenAIServing):

        result_generator = merge_async_iterators(*generators)

-        model_name = self._get_model_name(request.model, lora_request)
+        model_name = self.models.model_name(lora_request)
        num_prompts = len(engine_prompts)

        # Similar to the OpenAI API, when n != best_of, we do not stream the
@@ -599,7 +599,7 @@ class OpenAIServingEmbedding(EmbeddingMixin):
        See https://platform.openai.com/docs/api-reference/embeddings/create
        for the API specification. This API mimics the OpenAI Embedding API.
        """
-        model_name = self._get_model_name(request.model)
+        model_name = self.models.model_name()
        request_id = (
            f"{self.request_id_prefix}-"
            f"{self._base_request_id(raw_request, request.request_id)}")
@@ -980,17 +980,6 @@ class OpenAIServing:
            return True
        return self.models.is_base_model(model_name)

-    def _get_model_name(
-        self,
-        model_name: Optional[str] = None,
-        lora_request: Optional[LoRARequest] = None,
-    ) -> str:
-        if lora_request:
-            return lora_request.lora_name
-        if not model_name:
-            return self.models.base_model_paths[0].name
-        return model_name
-

 def clamp_prompt_logprobs(
     prompt_logprobs: Union[PromptLogprobs,
@@ -91,7 +91,7 @@ class OpenAIServingPooling(OpenAIServing):
        if error_check_ret is not None:
            return error_check_ret

-        model_name = self._get_model_name(request.model)
+        model_name = self.models.model_name()

        request_id = f"pool-{self._base_request_id(raw_request)}"
        created_time = int(time.time())
@@ -237,7 +237,7 @@ class OpenAIServingResponses(OpenAIServing):

        try:
            lora_request = self._maybe_get_adapters(request)
-            model_name = self._get_model_name(request.model, lora_request)
+            model_name = self.models.model_name(lora_request)
            tokenizer = await self.engine_client.get_tokenizer(lora_request)

            if self.use_harmony:
@@ -353,7 +353,7 @@ class ServingScores(OpenAIServing):
                final_res_batch,
                request_id,
                created_time,
-                self._get_model_name(request.model),
+                self.models.model_name(),
            )
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
@@ -399,7 +399,7 @@ class ServingScores(OpenAIServing):
            return self.request_output_to_rerank_response(
                final_res_batch,
                request_id,
-                self._get_model_name(request.model),
+                self.models.model_name(),
                documents,
                top_n,
            )