[HTTP Server] Make model param optional in request (#13568)

parent: 8c0dd3d4df
commit: 0ffdf8ce0c
@@ -9,6 +9,7 @@ import jsonschema
 import openai  # use the official client for correctness check
 import pytest
 import pytest_asyncio
+import requests
 import torch
 from openai import BadRequestError
 
@@ -996,3 +997,34 @@ async def test_long_seed(client: openai.AsyncOpenAI):
 
         assert ("greater_than_equal" in exc_info.value.message
                 or "less_than_equal" in exc_info.value.message)
+
+
+@pytest.mark.asyncio
+async def test_http_chat_wo_model_name(server: RemoteOpenAIServer):
+    url = f"http://localhost:{server.port}/v1/chat/completions"
+    headers = {
+        "Content-Type": "application/json",
+    }
+    data = {
+        # model_name is avoided here.
+        "messages": [{
+            "role": "system",
+            "content": "You are a helpful assistant."
+        }, {
+            "role": "user",
+            "content": "what is 1+1?"
+        }],
+        "max_tokens":
+        5
+    }
+
+    response = requests.post(url, headers=headers, json=data)
+    response_data = response.json()
+    print(response_data)
+
+    choice = response_data.get("choices")[0]
+    message = choice.get("message")
+    assert message is not None
+    content = message.get("content")
+    assert content is not None
+    assert len(content) > 0
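For reference, here is roughly what the new behaviour looks like from a plain client outside the pytest harness. This is a minimal sketch, not part of the commit: the localhost:8000 host and port are assumptions, and the payload simply mirrors the test above with the "model" field omitted.

# Minimal client sketch (not part of the diff): POST a chat completion
# without a "model" field. Assumes a vLLM OpenAI-compatible server is
# already listening on localhost:8000.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    headers={"Content-Type": "application/json"},
    json={
        # "model" is intentionally omitted; the server should fall back
        # to its default served model.
        "messages": [{"role": "user", "content": "what is 1+1?"}],
        "max_tokens": 5,
    },
)
print(resp.json()["choices"][0]["message"]["content"])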
@@ -213,7 +213,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/chat/create
     messages: List[ChatCompletionMessageParam]
-    model: str
+    model: Optional[str] = None
     frequency_penalty: Optional[float] = 0.0
     logit_bias: Optional[Dict[str, float]] = None
     logprobs: Optional[bool] = False
@@ -642,7 +642,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
 class CompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/completions/create
-    model: str
+    model: Optional[str] = None
     prompt: Union[List[int], List[List[int]], str, List[str]]
     best_of: Optional[int] = None
     echo: Optional[bool] = False
@@ -907,7 +907,7 @@ class CompletionRequest(OpenAIBaseModel):
 class EmbeddingCompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/embeddings
-    model: str
+    model: Optional[str] = None
     input: Union[List[int], List[List[int]], str, List[str]]
     encoding_format: Literal["float", "base64"] = "float"
     dimensions: Optional[int] = None
@@ -939,7 +939,7 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
 
 
 class EmbeddingChatRequest(OpenAIBaseModel):
-    model: str
+    model: Optional[str] = None
     messages: List[ChatCompletionMessageParam]
 
     encoding_format: Literal["float", "base64"] = "float"
@@ -1007,7 +1007,7 @@ PoolingRequest = Union[PoolingCompletionRequest, PoolingChatRequest]
 
 
 class ScoreRequest(OpenAIBaseModel):
-    model: str
+    model: Optional[str] = None
     text_1: Union[List[str], str]
     text_2: Union[List[str], str]
     truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
@@ -1031,7 +1031,7 @@ class ScoreRequest(OpenAIBaseModel):
 
 
 class RerankRequest(OpenAIBaseModel):
-    model: str
+    model: Optional[str] = None
     query: str
     documents: List[str]
     top_n: int = Field(default_factory=lambda: 0)
@@ -1345,7 +1345,7 @@ class BatchRequestOutput(OpenAIBaseModel):
 
 
 class TokenizeCompletionRequest(OpenAIBaseModel):
-    model: str
+    model: Optional[str] = None
     prompt: str
 
     add_special_tokens: bool = Field(
@@ -1357,7 +1357,7 @@ class TokenizeCompletionRequest(OpenAIBaseModel):
 
 
 class TokenizeChatRequest(OpenAIBaseModel):
-    model: str
+    model: Optional[str] = None
     messages: List[ChatCompletionMessageParam]
 
     add_generation_prompt: bool = Field(
@@ -1423,7 +1423,7 @@ class TokenizeResponse(OpenAIBaseModel):
 
 
 class DetokenizeRequest(OpenAIBaseModel):
-    model: str
+    model: Optional[str] = None
     tokens: List[int]
 
 
@@ -1456,7 +1456,7 @@ class TranscriptionRequest(OpenAIBaseModel):
     formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
     """
 
-    model: str
+    model: Optional[str] = None
     """ID of the model to use.
     """
 
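All of the request classes above are pydantic models (via OpenAIBaseModel), so switching the field to Optional[str] = None means a JSON body without "model" now passes validation instead of failing with a missing-field error. Below is a minimal standalone sketch of that behaviour; ChatRequest is a simplified stand-in, not the real ChatCompletionRequest.

# Simplified stand-in for the request classes changed above; only meant to
# illustrate the Optional-field behaviour.
from typing import List, Optional

from pydantic import BaseModel


class ChatRequest(BaseModel):
    messages: List[dict]
    model: Optional[str] = None  # was `model: str` before this change


# Previously this raised a "field required" validation error; now it
# validates and request.model is simply None.
request = ChatRequest(messages=[{"role": "user", "content": "hi"}])
assert request.model is None

# An explicit model name still round-trips unchanged.
named = ChatRequest(messages=[], model="my-model")
assert named.model == "my-model"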
@@ -141,7 +141,7 @@ class OpenAIServingChat(OpenAIServing):
                 prompt_adapter_request,
             ) = self._maybe_get_adapters(request)
 
-            model_name = self.models.model_name(lora_request)
+            model_name = self._get_model_name(request.model, lora_request)
 
             tokenizer = await self.engine_client.get_tokenizer(lora_request)
 
@@ -166,7 +166,7 @@ class OpenAIServingCompletion(OpenAIServing):
 
         result_generator = merge_async_iterators(*generators)
 
-        model_name = self.models.model_name(lora_request)
+        model_name = self._get_model_name(request.model, lora_request)
         num_prompts = len(engine_prompts)
 
         # Similar to the OpenAI API, when n != best_of, we do not stream the
@@ -83,7 +83,7 @@ class OpenAIServingEmbedding(OpenAIServing):
             return self.create_error_response(
                 "dimensions is currently not supported")
 
-        model_name = request.model
+        model_name = self._get_model_name(request.model)
         request_id = f"embd-{self._base_request_id(raw_request)}"
         created_time = int(time.time())
 
@@ -523,5 +523,16 @@ class OpenAIServing:
             return logprob.decoded_token
         return tokenizer.decode(token_id)
 
-    def _is_model_supported(self, model_name):
+    def _is_model_supported(self, model_name) -> bool:
+        if not model_name:
+            return True
         return self.models.is_base_model(model_name)
+
+    def _get_model_name(self,
+                        model_name: Optional[str] = None,
+                        lora_request: Optional[LoRARequest] = None) -> str:
+        if lora_request:
+            return lora_request.lora_name
+        if model_name is None:
+            return self.models.base_model_paths[0].name
+        return model_name
||||
@ -95,7 +95,7 @@ class OpenAIServingModels:
|
||||
if isinstance(load_result, ErrorResponse):
|
||||
raise ValueError(load_result.message)
|
||||
|
||||
def is_base_model(self, model_name):
|
||||
def is_base_model(self, model_name) -> bool:
|
||||
return any(model.name == model_name for model in self.base_model_paths)
|
||||
|
||||
def model_name(self, lora_request: Optional[LoRARequest] = None) -> str:
|
||||
|
||||
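Together with the _is_model_supported change above, request validation now lets an omitted model through (the default is resolved later) while still rejecting an explicit name that is not among the served base models. A small standalone sketch of that check follows; ServedModels is a simplified stand-in for OpenAIServingModels.

# Simplified stand-in for OpenAIServingModels, only to illustrate the
# relaxed validation: missing model -> accepted, unknown model -> rejected.
from typing import List, Optional


class ServedModels:

    def __init__(self, base_model_names: List[str]) -> None:
        self.base_model_names = base_model_names

    def is_base_model(self, model_name) -> bool:
        return model_name in self.base_model_names


def is_model_supported(models: ServedModels, model_name: Optional[str]) -> bool:
    # No "model" in the request: accept it and resolve the name later.
    if not model_name:
        return True
    return models.is_base_model(model_name)


models = ServedModels(["facebook/opt-125m"])
assert is_model_supported(models, None) is True
assert is_model_supported(models, "facebook/opt-125m") is True
assert is_model_supported(models, "some-unknown-model") is False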
@@ -79,7 +79,7 @@ class OpenAIServingPooling(OpenAIServing):
             return self.create_error_response(
                 "dimensions is currently not supported")
 
-        model_name = request.model
+        model_name = self._get_model_name(request.model)
         request_id = f"pool-{self._base_request_id(raw_request)}"
         created_time = int(time.time())
 
@@ -318,7 +318,7 @@ class ServingScores(OpenAIServing):
                 final_res_batch,
                 request_id,
                 created_time,
-                request.model,
+                self._get_model_name(request.model),
             )
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
@@ -358,7 +358,7 @@ class ServingScores(OpenAIServing):
                 request.truncate_prompt_tokens,
             )
             return self.request_output_to_rerank_response(
-                final_res_batch, request_id, request.model, documents, top_n)
+                final_res_batch, request_id, self._get_model_name(request.model), documents, top_n)
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
         except ValueError as e: