[FRONTEND] OpenAI tools support named functions (#5032)

parent 4f0d17c05c
commit f775a07e30
@@ -109,4 +109,15 @@ directory [here](https://github.com/vllm-project/vllm/tree/main/examples/)
:module: vllm.entrypoints.openai.cli_args
:func: make_arg_parser
:prog: -m vllm.entrypoints.openai.api_server
```

## Tool calling in the chat completion API
vLLM supports only named function calling in the chat completion API. The `tool_choice` options `auto` and `required` are **not yet supported** but are on the roadmap.

To use a named function, define the function in the `tools` parameter and select it in the `tool_choice` parameter.

It is the caller's responsibility to prompt the model with the tool information; vLLM will not automatically manipulate the prompt. **This may change in the future.**

vLLM will use guided decoding to ensure the response matches the tool parameter object defined by the JSON schema in the `tools` parameter.

Please refer to the OpenAI API reference documentation for more information.
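A minimal end-to-end sketch of the flow described above (not part of the commit): it assumes a vLLM OpenAI-compatible server is already running on localhost and that the model name below matches whatever the server serves; the weather tool, its schema, and the prompt are made up for illustration. As the doc notes, the caller still has to describe the tool in the prompt.

```python
# Hedged sketch: base_url, model name and the weather tool are illustrative only.
import json

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",  # hypothetical tool
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]

chat = client.chat.completions.create(
    model="meta-llama/Llama-2-7b-chat-hf",  # assumed; use the served model name
    messages=[{
        "role": "user",
        "content": "What is the weather in Berlin? Answer via the weather tool."
    }],
    tools=tools,
    # Named function calling: guided decoding forces the reply to match the
    # JSON schema given in `parameters` above.
    tool_choice={"type": "function", "function": {"name": "get_current_weather"}},
)

arguments = chat.choices[0].message.tool_calls[0].function.arguments
print(json.loads(arguments))  # e.g. {"city": "Berlin"}
```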
@@ -906,6 +906,191 @@ async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI,
                   for token in top_logprobs)


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
                         ["outlines", "lm-format-enforcer"])
async def test_named_tool_use(server, client: openai.AsyncOpenAI,
                              guided_decoding_backend: str):
    messages = [{
        "role": "system",
        "content": "you are a helpful assistant"
    }, {
        "role":
        "user",
        "content":
        f"Give an example JSON for an employee profile that "
        f"fits this schema: {TEST_SCHEMA}"
    }]

    # non-streaming

    chat_completion = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        max_tokens=1000,
        tools=[{
            "type": "function",
            "function": {
                "name": "dummy_function_name",
                "description": "This is a dummy function",
                "parameters": TEST_SCHEMA
            }
        }],
        tool_choice={
            "type": "function",
            "function": {
                "name": "dummy_function_name"
            }
        })
    message = chat_completion.choices[0].message
    assert len(message.content) == 0
    json_string = message.tool_calls[0].function.arguments
    json1 = json.loads(json_string)
    jsonschema.validate(instance=json1, schema=TEST_SCHEMA)

    messages.append({"role": "assistant", "content": json_string})
    messages.append({
        "role":
        "user",
        "content":
        "Give me another one with a different name and age"
    })

    # streaming

    stream = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        max_tokens=1000,
        tools=[{
            "type": "function",
            "function": {
                "name": "dummy_function_name",
                "description": "This is a dummy function",
                "parameters": TEST_SCHEMA
            }
        }],
        tool_choice={
            "type": "function",
            "function": {
                "name": "dummy_function_name"
            }
        },
        stream=True)

    output = []
    finish_reason_count = 0
    async for chunk in stream:
        delta = chunk.choices[0].delta
        if delta.role:
            assert delta.role == "assistant"
        assert delta.content is None or len(delta.content) == 0
        if delta.tool_calls:
            output.append(delta.tool_calls[0].function.arguments)
        if chunk.choices[0].finish_reason is not None:
            finish_reason_count += 1
    # finish reason should only return in last block
    assert finish_reason_count == 1
    json2 = json.loads("".join(output))
    jsonschema.validate(instance=json2, schema=TEST_SCHEMA)
    assert json1["name"] != json2["name"]
    assert json1["age"] != json2["age"]


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", ["outlines"])
async def test_required_tool_use_not_yet_supported(
        server, client: openai.AsyncOpenAI, guided_decoding_backend: str):
    messages = [{
        "role": "system",
        "content": "you are a helpful assistant"
    }, {
        "role":
        "user",
        "content":
        f"Give an example JSON for an employee profile that "
        f"fits this schema: {TEST_SCHEMA}"
    }]

    with pytest.raises(openai.BadRequestError):
        await client.chat.completions.create(
            model=MODEL_NAME,
            messages=messages,
            max_tokens=1000,
            tools=[{
                "type": "function",
                "function": {
                    "name": "dummy_function_name",
                    "description": "This is a dummy function",
                    "parameters": TEST_SCHEMA
                }
            }],
            tool_choice="required")

    with pytest.raises(openai.BadRequestError):
        await client.chat.completions.create(
            model=MODEL_NAME,
            messages=messages,
            max_tokens=1000,
            tools=[{
                "type": "function",
                "function": {
                    "name": "dummy_function_name",
                    "description": "This is a dummy function",
                    "parameters": TEST_SCHEMA
                }
            }],
            tool_choice="auto")


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", ["outlines"])
async def test_inconsistent_tool_choice_and_tools(
        server, client: openai.AsyncOpenAI, guided_decoding_backend: str):
    messages = [{
        "role": "system",
        "content": "you are a helpful assistant"
    }, {
        "role":
        "user",
        "content":
        f"Give an example JSON for an employee profile that "
        f"fits this schema: {TEST_SCHEMA}"
    }]

    with pytest.raises(openai.BadRequestError):
        await client.chat.completions.create(model=MODEL_NAME,
                                             messages=messages,
                                             max_tokens=1000,
                                             tool_choice={
                                                 "type": "function",
                                                 "function": {
                                                     "name":
                                                     "dummy_function_name"
                                                 }
                                             })

    with pytest.raises(openai.BadRequestError):
        await client.chat.completions.create(
            model=MODEL_NAME,
            messages=messages,
            max_tokens=1000,
            tools=[{
                "type": "function",
                "function": {
                    "name": "dummy_function_name",
                    "description": "This is a dummy function",
                    "parameters": TEST_SCHEMA
                }
            }],
            tool_choice={
                "type": "function",
                "function": {
                    "name": "nondefined_function_name"
                }
            })


@pytest.mark.asyncio
async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
    for _ in range(2):
@@ -24,7 +24,8 @@ class ServerRunner:
        env = os.environ.copy()
        env["PYTHONUNBUFFERED"] = "1"
        self.proc = subprocess.Popen(
-           ["python3", "-m", "vllm.entrypoints.openai.api_server"] + args,
+           [sys.executable, "-m", "vllm.entrypoints.openai.api_server"] +
+           args,
            env=env,
            stdout=sys.stdout,
            stderr=sys.stderr,
@@ -102,6 +102,26 @@ class ResponseFormat(OpenAIBaseModel):
    type: Literal["text", "json_object"]


class FunctionDefinition(OpenAIBaseModel):
    name: str
    description: Optional[str] = None
    parameters: Optional[Dict[str, Any]] = None


class ChatCompletionToolsParam(OpenAIBaseModel):
    type: Literal["function"] = "function"
    function: FunctionDefinition


class ChatCompletionNamedFunction(OpenAIBaseModel):
    name: str


class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
    function: ChatCompletionNamedFunction
    type: Literal["function"] = "function"


class ChatCompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/chat/create
@@ -122,6 +142,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
    stream: Optional[bool] = False
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    tools: Optional[List[ChatCompletionToolsParam]] = None
    tool_choice: Optional[Union[Literal["none"],
                                ChatCompletionNamedToolChoiceParam]] = "none"
    user: Optional[str] = None

    # doc: begin-chat-completion-sampling-params
@@ -245,10 +268,27 @@ class ChatCompletionRequest(OpenAIBaseModel):
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        # you can only use one kind of guided decoding
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        # you can only either use guided decoding or tools, not both
        if guide_count > 1 and "tool_choice" in data and data[
                "tool_choice"] != "none":
            raise ValueError(
                "You can only either use guided decoding or tools, not both.")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_tool_choice(cls, data):
        if "tool_choice" in data and data["tool_choice"] != "none":
            if not isinstance(data["tool_choice"], dict):
                raise ValueError("Currently only named tools are supported.")
            if "tools" not in data or data["tools"] is None:
                raise ValueError(
                    "When using `tool_choice`, `tools` must be set.")
        return data

    @model_validator(mode="before")
@@ -506,9 +546,21 @@ class EmbeddingResponse(BaseModel):
    usage: UsageInfo


class FunctionCall(OpenAIBaseModel):
    name: str
    arguments: str


class ToolCall(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
    type: Literal["function"] = "function"
    function: FunctionCall


class ChatMessage(OpenAIBaseModel):
    role: str
    content: str
    tool_calls: List[ToolCall] = Field(default_factory=list)


class ChatCompletionLogProb(OpenAIBaseModel):
@@ -535,7 +587,7 @@ class ChatCompletionResponseChoice(OpenAIBaseModel):

class ChatCompletionResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
-   object: str = "chat.completion"
+   object: Literal["chat.completion"] = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
@@ -545,6 +597,7 @@ class ChatCompletionResponse(OpenAIBaseModel):
class DeltaMessage(OpenAIBaseModel):
    role: Optional[str] = None
    content: Optional[str] = None
    tool_calls: List[ToolCall] = Field(default_factory=list)


class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
@@ -557,7 +610,7 @@ class ChatCompletionResponseStreamChoice(OpenAIBaseModel):

class ChatCompletionStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
-   object: str = "chat.completion.chunk"
+   object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
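As a quick illustration of how the new request-side models nest, here is a small sketch (not taken from the commit): the tool name and schema are invented, and `model_dump_json` assumes these classes are pydantic v2 models as in the surrounding file.

```python
# Sketch only: exercises FunctionDefinition / ChatCompletionToolsParam /
# ChatCompletionNamedToolChoiceParam added above with made-up values.
from vllm.entrypoints.openai.protocol import (ChatCompletionNamedFunction,
                                              ChatCompletionNamedToolChoiceParam,
                                              ChatCompletionToolsParam,
                                              FunctionDefinition)

schema = {"type": "object", "properties": {"city": {"type": "string"}}}

tool = ChatCompletionToolsParam(function=FunctionDefinition(
    name="get_current_weather", description="dummy", parameters=schema))
choice = ChatCompletionNamedToolChoiceParam(
    function=ChatCompletionNamedFunction(name="get_current_weather"))

print(tool.model_dump_json())    # {"type": "function", "function": {...}}
print(choice.model_dump_json())  # {"function": {"name": "..."}, "type": "function"}
```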
@@ -14,10 +14,11 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (
    ChatCompletionContentPartParam, ChatCompletionLogProb,
    ChatCompletionLogProbs, ChatCompletionLogProbsContent,
-   ChatCompletionMessageParam, ChatCompletionRequest, ChatCompletionResponse,
+   ChatCompletionMessageParam, ChatCompletionNamedToolChoiceParam,
+   ChatCompletionRequest, ChatCompletionResponse,
    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
-   UsageInfo)
+   FunctionCall, ToolCall, UsageInfo)
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                    OpenAIServing)
from vllm.logger import init_logger
@@ -298,11 +299,24 @@ class OpenAIServingChat(OpenAIServing):
                delta_text = output.text[len(previous_texts[i]):]
                previous_texts[i] = output.text
                previous_num_tokens[i] = len(output.token_ids)

                if request.tool_choice and type(
                        request.tool_choice
                ) is ChatCompletionNamedToolChoiceParam:
                    delta_message = DeltaMessage(tool_calls=[
                        ToolCall(function=FunctionCall(
                            name=request.tool_choice.function.name,
                            arguments=delta_text))
                    ])
                else:
                    delta_message = DeltaMessage(content=delta_text)

                if output.finish_reason is None:
                    # Send token-by-token response for each request.n

                    choice_data = ChatCompletionResponseStreamChoice(
                        index=i,
-                       delta=DeltaMessage(content=delta_text),
+                       delta=delta_message,
                        logprobs=logprobs,
                        finish_reason=None)
                    chunk = ChatCompletionStreamResponse(
@@ -324,7 +338,7 @@ class OpenAIServingChat(OpenAIServing):
                    )
                    choice_data = ChatCompletionResponseStreamChoice(
                        index=i,
-                       delta=DeltaMessage(content=delta_text),
+                       delta=delta_message,
                        logprobs=logprobs,
                        finish_reason=output.finish_reason,
                        stop_reason=output.stop_reason)
@@ -381,9 +395,22 @@ class OpenAIServingChat(OpenAIServing):
            else:
                logprobs = None

            if request.tool_choice and type(
                    request.tool_choice) is ChatCompletionNamedToolChoiceParam:
                message = ChatMessage(
                    role=role,
                    content="",
                    tool_calls=[
                        ToolCall(function=FunctionCall(
                            name=request.tool_choice.function.name,
                            arguments=output.text))
                    ])
            elif not request.tool_choice or request.tool_choice == "none":
                message = ChatMessage(role=role, content=output.text)

            choice_data = ChatCompletionResponseChoice(
                index=output.index,
-               message=ChatMessage(role=role, content=output.text),
+               message=message,
                logprobs=logprobs,
                finish_reason=output.finish_reason,
                stop_reason=output.stop_reason)
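With a named `tool_choice`, the handlers above now put the generated text into `tool_calls` instead of `content`. Roughly, the resulting assistant message looks like the following illustration (the id and arguments are made up; the id format comes from the `ToolCall` default factory):

```python
# Illustrative shape only; `arguments` is whatever guided decoding produced.
example_message = {
    "role": "assistant",
    "content": "",
    "tool_calls": [{
        "id": "chatcmpl-tool-<random-uuid>",
        "type": "function",
        "function": {
            "name": "dummy_function_name",
            "arguments": '{"name": "Alice", "age": 30}',
        },
    }],
}
```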
@@ -1,7 +1,8 @@
from typing import Optional, Union

-from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
-                                              CompletionRequest)
+from vllm.entrypoints.openai.protocol import (
+    ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
+    CompletionRequest)
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (
    get_lm_format_enforcer_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.outlines_decoding import (
@@ -13,6 +14,8 @@ async def get_guided_decoding_logits_processor(
        guided_decoding_backend: str, request: Union[CompletionRequest,
                                                     ChatCompletionRequest],
        tokenizer) -> Optional[LogitsProcessor]:
+   request = _adapt_request_for_tool_use(request)

    if guided_decoding_backend == 'outlines':
        return await get_outlines_guided_decoding_logits_processor(
            request, tokenizer)
@@ -23,3 +26,26 @@ async def get_guided_decoding_logits_processor(
    raise ValueError(
        f"Unknown guided decoding backend '{guided_decoding_backend}'. "
        "Must be one of 'outlines, 'lm-format-enforcer'")


def _adapt_request_for_tool_use(request: Union[CompletionRequest,
                                               ChatCompletionRequest]):
    # the legacy completion API does not support tool use
    if type(request) is CompletionRequest:
        return request

    # user has chosen to not use any tool
    if request.tool_choice == "none":
        return request

    # user has chosen to use a named tool
    if type(request.tool_choice) is ChatCompletionNamedToolChoiceParam:
        tool_name = request.tool_choice.function.name
        tools = {tool.function.name: tool.function for tool in request.tools}
        if tool_name not in tools:
            raise ValueError(
                f"Tool '{tool_name}' has not been passed in `tools`.")
        tool = tools[tool_name]
        request.guided_json = tool.parameters

    return request
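In effect, `_adapt_request_for_tool_use` rewrites a named-tool chat request into an ordinary `guided_json` request, so the existing guided-decoding backends need no tool-specific handling. A tiny sketch of that effect, where `request` is a hypothetical `ChatCompletionRequest` whose single tool is also the one selected by `tool_choice`:

```python
# Hypothetical: after adaptation, the tool's JSON schema drives guided decoding.
adapted = _adapt_request_for_tool_use(request)
assert adapted.guided_json == request.tools[0].function.parameters
```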