[Bugfix] Refactor /invocations to be task-agnostic (#20764)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:

parent 7bd4c37ae7
commit cbd14ed561
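
The tests added in this commit all follow the same pattern: send one payload to the task-specific route and to /invocations, then check that the two responses agree. Below is a minimal client-side sketch of that idea, assuming an OpenAI-compatible vLLM server is already running with a chat model; the base URL and model name are placeholders, not values taken from this change.

    # Illustrative only: the server URL and model name are placeholders.
    import requests

    BASE_URL = "http://localhost:8000"  # assumed OpenAI-compatible vLLM server
    MODEL = "my-model"                  # assumed served model name

    # The body shape alone decides which handler serves the request; for a
    # chat model the same payload can go to either route.
    request_args = {
        "model": MODEL,
        "messages": [{"role": "user", "content": "what is 1+1?"}],
        "max_completion_tokens": 5,
        "temperature": 0.0,
    }

    chat_response = requests.post(f"{BASE_URL}/v1/chat/completions",
                                  json=request_args)
    chat_response.raise_for_status()

    invocation_response = requests.post(f"{BASE_URL}/invocations",
                                        json=request_args)
    invocation_response.raise_for_status()

    assert chat_response.json()["choices"] == invocation_response.json()["choices"]

Since routing now keys off the request body rather than the server's --task, the same /invocations route also accepts completion, embedding, classification, score, rerank, and pooling payloads, as the per-endpoint tests in the diff below verify.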
@@ -1113,10 +1113,7 @@ async def test_http_chat_no_model_name_with_curl(server: RemoteOpenAIServer):


 @pytest.mark.asyncio
-@pytest.mark.parametrize("model_name", [MODEL_NAME, ""])
-async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer,
-                                                   model_name: str):
+async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer):
     openai_api_key = "EMPTY"
     openai_api_base = f"http://localhost:{server.port}/v1"

@@ -1135,3 +1132,35 @@ async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer,
         messages=messages,
     )
     assert response.model == MODEL_NAME
+
+
+@pytest.mark.asyncio
+async def test_invocations(server: RemoteOpenAIServer,
+                           client: openai.AsyncOpenAI):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role": "user",
+        "content": "what is 1+1?"
+    }]
+
+    request_args = {
+        "model": MODEL_NAME,
+        "messages": messages,
+        "max_completion_tokens": 5,
+        "temperature": 0.0,
+        "logprobs": False,
+    }
+
+    chat_completion = await client.chat.completions.create(**request_args)
+
+    invocation_response = requests.post(server.url_for("invocations"),
+                                        json=request_args)
+    invocation_response.raise_for_status()
+
+    chat_output = chat_completion.model_dump()
+    invocation_output = invocation_response.json()
+
+    assert chat_output.keys() == invocation_output.keys()
+    assert chat_output["choices"] == invocation_output["choices"]
@@ -155,3 +155,25 @@ def test_batch_classification_empty_list(server: RemoteOpenAIServer,
     assert output.object == "list"
     assert isinstance(output.data, list)
     assert len(output.data) == 0
+
+
+@pytest.mark.asyncio
+async def test_invocations(server: RemoteOpenAIServer):
+    request_args = {
+        "model": MODEL_NAME,
+        "input": "This product was excellent and exceeded my expectations"
+    }
+
+    classification_response = requests.post(server.url_for("classify"),
+                                            json=request_args)
+    classification_response.raise_for_status()
+
+    invocation_response = requests.post(server.url_for("invocations"),
+                                        json=request_args)
+    invocation_response.raise_for_status()
+
+    classification_output = classification_response.json()
+    invocation_output = invocation_response.json()
+
+    assert classification_output.keys() == invocation_output.keys()
+    assert classification_output["data"] == invocation_output["data"]
@@ -11,6 +11,7 @@ import openai # use the official client for correctness check
 import pytest
 import pytest_asyncio
 import regex as re
+import requests
 # downloading lora to test lora requests
 from huggingface_hub import snapshot_download
 from openai import BadRequestError
@@ -833,3 +834,27 @@ async def test_echo_stream_completion(client: openai.AsyncOpenAI,
         assert content is not None and saying in content
     else:
         assert content is not None and saying not in content
+
+
+@pytest.mark.asyncio
+async def test_invocations(server: RemoteOpenAIServer,
+                           client: openai.AsyncOpenAI):
+    request_args = {
+        "model": MODEL_NAME,
+        "prompt": "Hello, my name is",
+        "max_tokens": 5,
+        "temperature": 0.0,
+        "logprobs": None,
+    }
+
+    completion = await client.completions.create(**request_args)
+
+    invocation_response = requests.post(server.url_for("invocations"),
+                                        json=request_args)
+    invocation_response.raise_for_status()
+
+    completion_output = completion.model_dump()
+    invocation_output = invocation_response.json()
+
+    assert completion_output.keys() == invocation_output.keys()
+    assert completion_output["choices"] == invocation_output["choices"]
@@ -296,3 +296,63 @@ async def test_single_embedding_truncation_invalid(client: openai.AsyncOpenAI,
     assert "error" in response.object
     assert "truncate_prompt_tokens value is greater than max_model_len. "\
         "Please, select a smaller truncation size." in response.message
+
+
+@pytest.mark.asyncio
+async def test_invocations(server: RemoteOpenAIServer,
+                           client: openai.AsyncOpenAI):
+    input_texts = [
+        "The chef prepared a delicious meal.",
+    ]
+
+    request_args = {
+        "model": MODEL_NAME,
+        "input": input_texts,
+        "encoding_format": "float",
+    }
+
+    completion_response = await client.embeddings.create(**request_args)
+
+    invocation_response = requests.post(server.url_for("invocations"),
+                                        json=request_args)
+    invocation_response.raise_for_status()
+
+    completion_output = completion_response.model_dump()
+    invocation_output = invocation_response.json()
+
+    assert completion_output.keys() == invocation_output.keys()
+    assert completion_output["data"] == invocation_output["data"]
+
+
+@pytest.mark.asyncio
+async def test_invocations_conversation(server: RemoteOpenAIServer):
+    messages = [{
+        "role": "user",
+        "content": "The cat sat on the mat.",
+    }, {
+        "role": "assistant",
+        "content": "A feline was resting on a rug.",
+    }, {
+        "role": "user",
+        "content": "Stars twinkle brightly in the night sky.",
+    }]
+
+    request_args = {
+        "model": MODEL_NAME,
+        "messages": messages,
+        "encoding_format": "float",
+    }
+
+    chat_response = requests.post(server.url_for("v1/embeddings"),
+                                  json=request_args)
+    chat_response.raise_for_status()
+
+    invocation_response = requests.post(server.url_for("invocations"),
+                                        json=request_args)
+    invocation_response.raise_for_status()
+
+    chat_output = chat_response.json()
+    invocation_output = invocation_response.json()
+
+    assert chat_output.keys() == invocation_output.keys()
+    assert chat_output["data"] == invocation_output["data"]
@@ -13,7 +13,7 @@ from vllm.transformers_utils.tokenizer import get_tokenizer

 from ...utils import RemoteOpenAIServer

-MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach"
+MODEL_NAME = "internlm/internlm2-1_8b-reward"
 DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501


@@ -21,15 +21,16 @@ DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' +
 def server():
     args = [
         "--task",
-        "classify",
+        "reward",
         # use half precision for speed and memory savings in CI environment
         "--dtype",
         "bfloat16",
         "--enforce-eager",
         "--max-model-len",
-        "8192",
+        "512",
         "--chat-template",
         DUMMY_CHAT_TEMPLATE,
+        "--trust-remote-code",
     ]

     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@@ -57,10 +58,10 @@ async def test_single_pooling(server: RemoteOpenAIServer, model_name: str):

     assert poolings.id is not None
     assert len(poolings.data) == 1
-    assert len(poolings.data[0].data) == 2
+    assert len(poolings.data[0].data) == 8
     assert poolings.usage.completion_tokens == 0
-    assert poolings.usage.prompt_tokens == 7
-    assert poolings.usage.total_tokens == 7
+    assert poolings.usage.prompt_tokens == 8
+    assert poolings.usage.total_tokens == 8

     # test using token IDs
     input_tokens = [1, 1, 1, 1, 1]
@@ -77,7 +78,7 @@ async def test_single_pooling(server: RemoteOpenAIServer, model_name: str):

     assert poolings.id is not None
     assert len(poolings.data) == 1
-    assert len(poolings.data[0].data) == 2
+    assert len(poolings.data[0].data) == 5
     assert poolings.usage.completion_tokens == 0
     assert poolings.usage.prompt_tokens == 5
     assert poolings.usage.total_tokens == 5
@@ -104,10 +105,10 @@ async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str):

     assert poolings.id is not None
     assert len(poolings.data) == 3
-    assert len(poolings.data[0].data) == 2
+    assert len(poolings.data[0].data) == 8
     assert poolings.usage.completion_tokens == 0
-    assert poolings.usage.prompt_tokens == 25
-    assert poolings.usage.total_tokens == 25
+    assert poolings.usage.prompt_tokens == 29
+    assert poolings.usage.total_tokens == 29

     # test list[list[int]]
     input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
@@ -125,7 +126,7 @@ async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str):

     assert poolings.id is not None
     assert len(poolings.data) == 4
-    assert len(poolings.data[0].data) == 2
+    assert len(poolings.data[0].data) == 5
     assert poolings.usage.completion_tokens == 0
     assert poolings.usage.prompt_tokens == 17
     assert poolings.usage.total_tokens == 17
@@ -157,7 +158,11 @@ async def test_conversation_pooling(server: RemoteOpenAIServer,
     chat_response.raise_for_status()
     chat_poolings = PoolingResponse.model_validate(chat_response.json())

-    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
+    tokenizer = get_tokenizer(
+        tokenizer_name=model_name,
+        tokenizer_mode="fast",
+        trust_remote_code=True,
+    )
     prompt = tokenizer.apply_chat_template(
         messages,
         chat_template=DUMMY_CHAT_TEMPLATE,
@@ -206,6 +211,9 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer,
     )
     float_response.raise_for_status()
     responses_float = PoolingResponse.model_validate(float_response.json())
+    float_data = [
+        np.array(d.data).squeeze(-1).tolist() for d in responses_float.data
+    ]

     base64_response = requests.post(
         server.url_for("pooling"),
@@ -224,11 +232,10 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer,
             np.frombuffer(base64.b64decode(data.data),
                           dtype="float32").tolist())

-    check_embeddings_close(
-        embeddings_0_lst=[d.data for d in responses_float.data],
-        embeddings_1_lst=decoded_responses_base64_data,
-        name_0="float32",
-        name_1="base64")
+    check_embeddings_close(embeddings_0_lst=float_data,
+                           embeddings_1_lst=decoded_responses_base64_data,
+                           name_0="float32",
+                           name_1="base64")

     # Default response is float32 decoded from base64 by OpenAI Client
     default_response = requests.post(
@@ -240,9 +247,71 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer,
     )
     default_response.raise_for_status()
     responses_default = PoolingResponse.model_validate(default_response.json())
+    default_data = [
+        np.array(d.data).squeeze(-1).tolist() for d in responses_default.data
+    ]

-    check_embeddings_close(
-        embeddings_0_lst=[d.data for d in responses_default.data],
-        embeddings_1_lst=[d.data for d in responses_default.data],
-        name_0="float32",
-        name_1="base64")
+    check_embeddings_close(embeddings_0_lst=float_data,
+                           embeddings_1_lst=default_data,
+                           name_0="float32",
+                           name_1="default")
+
+
+@pytest.mark.asyncio
+async def test_invocations(server: RemoteOpenAIServer):
+    input_texts = [
+        "The chef prepared a delicious meal.",
+    ]
+
+    request_args = {
+        "model": MODEL_NAME,
+        "input": input_texts,
+        "encoding_format": "float",
+    }
+
+    completion_response = requests.post(server.url_for("pooling"),
+                                        json=request_args)
+    completion_response.raise_for_status()
+
+    invocation_response = requests.post(server.url_for("invocations"),
+                                        json=request_args)
+    invocation_response.raise_for_status()
+
+    completion_output = completion_response.json()
+    invocation_output = invocation_response.json()
+
+    assert completion_output.keys() == invocation_output.keys()
+    assert completion_output["data"] == invocation_output["data"]
+
+
+@pytest.mark.asyncio
+async def test_invocations_conversation(server: RemoteOpenAIServer):
+    messages = [{
+        "role": "user",
+        "content": "The cat sat on the mat.",
+    }, {
+        "role": "assistant",
+        "content": "A feline was resting on a rug.",
+    }, {
+        "role": "user",
+        "content": "Stars twinkle brightly in the night sky.",
+    }]
+
+    request_args = {
+        "model": MODEL_NAME,
+        "messages": messages,
+        "encoding_format": "float",
+    }
+
+    chat_response = requests.post(server.url_for("pooling"), json=request_args)
+    chat_response.raise_for_status()
+
+    invocation_response = requests.post(server.url_for("invocations"),
+                                        json=request_args)
+    invocation_response.raise_for_status()
+
+    chat_output = chat_response.json()
+    invocation_output = invocation_response.json()
+
+    assert chat_output.keys() == invocation_output.keys()
+    assert chat_output["data"] == invocation_output["data"]
@@ -94,3 +94,30 @@ def test_rerank_max_model_len(server: RemoteOpenAIServer, model_name: str):
     # Assert just a small fragments of the response
     assert "Please reduce the length of the input." in \
         rerank_response.text
+
+
+def test_invocations(server: RemoteOpenAIServer):
+    query = "What is the capital of France?"
+    documents = [
+        "The capital of Brazil is Brasilia.", "The capital of France is Paris."
+    ]
+
+    request_args = {
+        "model": MODEL_NAME,
+        "query": query,
+        "documents": documents,
+    }
+
+    rerank_response = requests.post(server.url_for("rerank"),
+                                    json=request_args)
+    rerank_response.raise_for_status()
+
+    invocation_response = requests.post(server.url_for("invocations"),
+                                        json=request_args)
+    invocation_response.raise_for_status()
+
+    rerank_output = rerank_response.json()
+    invocation_output = invocation_response.json()
+
+    assert rerank_output.keys() == invocation_output.keys()
+    assert rerank_output["results"] == invocation_output["results"]
@@ -191,3 +191,28 @@ class TestModel:
         assert score_response.status_code == 400
         assert "Please, select a smaller truncation size." in \
             score_response.text
+
+    def test_invocations(self, server: RemoteOpenAIServer, model: dict[str,
+                                                                       Any]):
+        text_1 = "What is the capital of France?"
+        text_2 = "The capital of France is Paris."
+
+        request_args = {
+            "model": model["name"],
+            "text_1": text_1,
+            "text_2": text_2,
+        }
+
+        score_response = requests.post(server.url_for("score"),
+                                       json=request_args)
+        score_response.raise_for_status()
+
+        invocation_response = requests.post(server.url_for("invocations"),
+                                            json=request_args)
+        invocation_response.raise_for_status()
+
+        score_output = score_response.json()
+        invocation_output = invocation_response.json()
+
+        assert score_output.keys() == invocation_output.keys()
+        assert score_output["data"] == invocation_output["data"]
@@ -18,7 +18,7 @@ from collections.abc import AsyncIterator, Awaitable
 from contextlib import asynccontextmanager
 from functools import partial
 from http import HTTPStatus
-from typing import Annotated, Any, Optional
+from typing import Annotated, Any, Callable, Optional

 import prometheus_client
 import pydantic
@@ -61,13 +61,9 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               CompletionResponse,
                                               DetokenizeRequest,
                                               DetokenizeResponse,
-                                              EmbeddingChatRequest,
-                                              EmbeddingCompletionRequest,
                                               EmbeddingRequest,
                                               EmbeddingResponse, ErrorResponse,
                                               LoadLoRAAdapterRequest,
-                                              PoolingChatRequest,
-                                              PoolingCompletionRequest,
                                               PoolingRequest, PoolingResponse,
                                               RerankRequest, RerankResponse,
                                               ResponsesRequest,
@@ -434,6 +430,7 @@ async def get_server_load_metrics(request: Request):
 # - /v1/chat/completions
 # - /v1/completions
 # - /v1/audio/transcriptions
+# - /v1/audio/translations
 # - /v1/embeddings
 # - /pooling
 # - /classify
@@ -957,31 +954,6 @@ async def do_rerank_v2(request: RerankRequest, raw_request: Request):
     return await do_rerank(request, raw_request)


-TASK_HANDLERS: dict[str, dict[str, tuple]] = {
-    "generate": {
-        "messages": (ChatCompletionRequest, create_chat_completion),
-        "default": (CompletionRequest, create_completion),
-    },
-    "embed": {
-        "messages": (EmbeddingChatRequest, create_embedding),
-        "default": (EmbeddingCompletionRequest, create_embedding),
-    },
-    "score": {
-        "default": (RerankRequest, do_rerank)
-    },
-    "rerank": {
-        "default": (RerankRequest, do_rerank)
-    },
-    "reward": {
-        "messages": (PoolingChatRequest, create_pooling),
-        "default": (PoolingCompletionRequest, create_pooling),
-    },
-    "classify": {
-        "messages": (PoolingChatRequest, create_pooling),
-        "default": (PoolingCompletionRequest, create_pooling),
-    },
-}
-
 if envs.VLLM_SERVER_DEV_MODE:
     logger.warning("SECURITY WARNING: Development endpoints are enabled! "
                    "This should NOT be used in production!")
@@ -1033,6 +1005,30 @@ if envs.VLLM_SERVER_DEV_MODE:
         return JSONResponse(content={"is_sleeping": is_sleeping})


+# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
+# (requires typing_extensions >= 4.13)
+RequestType = Any
+GetHandlerFn = Callable[[Request], Optional[OpenAIServing]]
+EndpointFn = Callable[[RequestType, Request], Awaitable[Any]]
+
+# NOTE: Items defined earlier take higher priority
+INVOCATION_TYPES: list[tuple[RequestType, tuple[GetHandlerFn, EndpointFn]]] = [
+    (ChatCompletionRequest, (chat, create_chat_completion)),
+    (CompletionRequest, (completion, create_completion)),
+    (EmbeddingRequest, (embedding, create_embedding)),
+    (ClassificationRequest, (classify, create_classify)),
+    (ScoreRequest, (score, create_score)),
+    (RerankRequest, (rerank, do_rerank)),
+    (PoolingRequest, (pooling, create_pooling)),
+]
+
+# NOTE: Construct the TypeAdapters only once
+INVOCATION_VALIDATORS = [
+    (pydantic.TypeAdapter(request_type), (get_handler, endpoint))
+    for request_type, (get_handler, endpoint) in INVOCATION_TYPES
+]
+
+
 @router.post("/invocations",
              dependencies=[Depends(validate_json_request)],
              responses={
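
The rewritten handler in the next hunk walks these validators in order and dispatches to the first request type that the body validates against. Below is a standalone sketch of that first-match dispatch pattern; the request models and handlers here are made up for illustration, not the actual vLLM classes.

    # Illustrative sketch of first-match dispatch via pydantic TypeAdapters;
    # the request models and handlers below are invented for this example.
    from typing import Any, Callable

    import pydantic


    class ChatRequest(pydantic.BaseModel):
        model: str
        messages: list[dict[str, Any]]


    class EmbeddingRequest(pydantic.BaseModel):
        model: str
        input: list[str]


    def handle_chat(request: ChatRequest) -> str:
        return f"chat({len(request.messages)} messages)"


    def handle_embedding(request: EmbeddingRequest) -> str:
        return f"embed({len(request.input)} inputs)"


    # Items defined earlier take priority; the TypeAdapters are built once.
    VALIDATORS: list[tuple[pydantic.TypeAdapter, Callable[[Any], str]]] = [
        (pydantic.TypeAdapter(ChatRequest), handle_chat),
        (pydantic.TypeAdapter(EmbeddingRequest), handle_embedding),
    ]


    def dispatch(body: dict[str, Any]) -> str:
        # Try each request type in order; the first one that validates wins.
        for adapter, handler in VALIDATORS:
            try:
                request = adapter.validate_python(body)
            except pydantic.ValidationError:
                continue
            return handler(request)
        raise ValueError("Cannot find suitable handler for request.")


    print(dispatch({"model": "m", "messages": [{"role": "user", "content": "hi"}]}))
    print(dispatch({"model": "m", "input": ["The cat sat on the mat."]}))

Because the TypeAdapters are built once at module load, the per-request cost is only the validation attempts, and listing more specific request types first keeps ambiguous bodies from being claimed by a broader schema.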
@@ -1047,32 +1043,34 @@ if envs.VLLM_SERVER_DEV_MODE:
                 },
             })
 async def invocations(raw_request: Request):
-    """
-    For SageMaker, routes requests to other handlers based on model `task`.
-    """
+    """For SageMaker, routes requests based on the request type."""
     try:
         body = await raw_request.json()
     except json.JSONDecodeError as e:
         raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value,
                             detail=f"JSON decode error: {e}") from e

-    task = raw_request.app.state.task
+    valid_endpoints = [(validator, endpoint)
+                       for validator, (get_handler,
+                                       endpoint) in INVOCATION_VALIDATORS
+                       if get_handler(raw_request) is not None]

-    if task not in TASK_HANDLERS:
-        raise HTTPException(
-            status_code=400,
-            detail=f"Unsupported task: '{task}' for '/invocations'. "
-            f"Expected one of {set(TASK_HANDLERS.keys())}")
+    for request_validator, endpoint in valid_endpoints:
+        try:
+            request = request_validator.validate_python(body)
+        except pydantic.ValidationError:
+            continue

-    handler_config = TASK_HANDLERS[task]
-    if "messages" in body:
-        request_model, handler = handler_config["messages"]
-    else:
-        request_model, handler = handler_config["default"]
+        return await endpoint(request, raw_request)

-    # this is required since we lose the FastAPI automatic casting
-    request = request_model.model_validate(body)
-    return await handler(request, raw_request)
+    type_names = [
+        t.__name__ if isinstance(t := validator._type, type) else str(t)
+        for validator, _ in valid_endpoints
+    ]
+    msg = ("Cannot find suitable handler for request. "
+           f"Expected one of: {type_names}")
+    res = base(raw_request).create_error_response(message=msg)
+    return JSONResponse(content=res.model_dump(), status_code=res.code)


 if envs.VLLM_TORCH_PROFILER_DIR:
@@ -11,6 +11,12 @@ from typing import Annotated, Any, ClassVar, Literal, Optional, Union
 import regex as re
 import torch
 from fastapi import HTTPException, UploadFile
+# yapf: disable
+from openai.types.chat.chat_completion_audio import (
+    ChatCompletionAudio as OpenAIChatCompletionAudio)
+from openai.types.chat.chat_completion_message import (
+    Annotation as OpenAIAnnotation)
+# yapf: enable
 from openai.types.responses import (ResponseInputParam, ResponseOutputItem,
                                     ResponseOutputMessage, ResponsePrompt,
                                     ResponseStatus, ResponseTextConfig)
@@ -1393,11 +1399,16 @@ class CompletionResponseChoice(OpenAIBaseModel):

 class CompletionResponse(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
-    object: str = "text_completion"
+    object: Literal["text_completion"] = "text_completion"
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: list[CompletionResponseChoice]
+    service_tier: Optional[Literal["auto", "default", "flex", "scale",
+                                   "priority"]] = None
+    system_fingerprint: Optional[str] = None
     usage: UsageInfo
+
+    # vLLM-specific fields that are not in OpenAI spec
     kv_transfer_params: Optional[dict[str, Any]] = Field(
         default=None, description="KVTransfer parameters.")

@@ -1549,10 +1560,16 @@ class ExtractedToolCallInformation(BaseModel):

 class ChatMessage(OpenAIBaseModel):
     role: str
-    reasoning_content: Optional[str] = None
     content: Optional[str] = None
+    refusal: Optional[str] = None
+    annotations: Optional[OpenAIAnnotation] = None
+    audio: Optional[OpenAIChatCompletionAudio] = None
+    function_call: Optional[FunctionCall] = None
     tool_calls: list[ToolCall] = Field(default_factory=list)
+
+    # vLLM-specific fields that are not in OpenAI spec
+    reasoning_content: Optional[str] = None


 class ChatCompletionLogProb(OpenAIBaseModel):
     token: str
@@ -1587,7 +1604,12 @@ class ChatCompletionResponse(OpenAIBaseModel):
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: list[ChatCompletionResponseChoice]
+    service_tier: Optional[Literal["auto", "default", "flex", "scale",
+                                   "priority"]] = None
+    system_fingerprint: Optional[str] = None
     usage: UsageInfo
+
+    # vLLM-specific fields that are not in OpenAI spec
     prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None
     kv_transfer_params: Optional[dict[str, Any]] = Field(
         default=None, description="KVTransfer parameters.")