[FRONTEND] OpenAI tools support named functions (#5032)

Breno Faria 2024-06-04 01:25:29 +02:00 committed by GitHub
parent 4f0d17c05c
commit f775a07e30
6 changed files with 314 additions and 11 deletions

View File

@@ -109,4 +109,15 @@ directory [here](https://github.com/vllm-project/vllm/tree/main/examples/)
:module: vllm.entrypoints.openai.cli_args
:func: make_arg_parser
:prog: -m vllm.entrypoints.openai.api_server
```

## Tool calling in the chat completion API

vLLM supports only named function calling in the chat completion API. The `tool_choice` options `auto` and `required` are **not yet supported** but are on the roadmap.

To use a named function, define the function in the `tools` parameter and select it in the `tool_choice` parameter.

It is the caller's responsibility to prompt the model with the tool information; vLLM will not automatically manipulate the prompt. **This may change in the future.**

vLLM will use guided decoding to ensure the response matches the tool parameter object defined by the JSON schema in the `tools` parameter.

Please refer to the OpenAI API reference documentation for more information.
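
For reference, a minimal sketch of such a request with the `openai` Python client (server URL, model name, and schema are illustrative):

```python
from openai import OpenAI

# vLLM's OpenAI-compatible server; the API key is a placeholder unless
# the server was started with --api-key.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

weather_schema = {  # illustrative JSON schema for the tool's arguments
    "type": "object",
    "properties": {
        "city": {"type": "string"},
    },
    "required": ["city"],
}

completion = client.chat.completions.create(
    model="mistralai/Mistral-7B-Instruct-v0.2",  # whatever the server serves
    messages=[{
        "role": "user",
        # Remember: the caller must describe the tool in the prompt.
        "content": "Call get_weather to look up the weather in Berlin."
    }],
    tools=[{
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the current weather for a city",
            "parameters": weather_schema,
        },
    }],
    # Named tool choice: the output is constrained to the tool's schema.
    tool_choice={"type": "function", "function": {"name": "get_weather"}},
)

# A JSON string guaranteed by guided decoding to match weather_schema.
print(completion.choices[0].message.tool_calls[0].function.arguments)
```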

View File

@@ -906,6 +906,191 @@ async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI,
for token in top_logprobs)


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_named_tool_use(server, client: openai.AsyncOpenAI,
guided_decoding_backend: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role":
"user",
"content":
f"Give an example JSON for an employee profile that "
f"fits this schema: {TEST_SCHEMA}"
}]
# non-streaming
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=1000,
tools=[{
"type": "function",
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": TEST_SCHEMA
}
}],
tool_choice={
"type": "function",
"function": {
"name": "dummy_function_name"
}
})
message = chat_completion.choices[0].message
assert len(message.content) == 0
json_string = message.tool_calls[0].function.arguments
json1 = json.loads(json_string)
jsonschema.validate(instance=json1, schema=TEST_SCHEMA)
messages.append({"role": "assistant", "content": json_string})
messages.append({
"role":
"user",
"content":
"Give me another one with a different name and age"
})
# streaming
stream = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=1000,
tools=[{
"type": "function",
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": TEST_SCHEMA
}
}],
tool_choice={
"type": "function",
"function": {
"name": "dummy_function_name"
}
},
stream=True)
output = []
finish_reason_count = 0
async for chunk in stream:
delta = chunk.choices[0].delta
if delta.role:
assert delta.role == "assistant"
assert delta.content is None or len(delta.content) == 0
if delta.tool_calls:
output.append(delta.tool_calls[0].function.arguments)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
# finish reason should only return in last block
assert finish_reason_count == 1
json2 = json.loads("".join(output))
jsonschema.validate(instance=json2, schema=TEST_SCHEMA)
assert json1["name"] != json2["name"]
assert json1["age"] != json2["age"]


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", ["outlines"])
async def test_required_tool_use_not_yet_supported(
server, client: openai.AsyncOpenAI, guided_decoding_backend: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role":
"user",
"content":
f"Give an example JSON for an employee profile that "
f"fits this schema: {TEST_SCHEMA}"
}]
with pytest.raises(openai.BadRequestError):
await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=1000,
tools=[{
"type": "function",
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": TEST_SCHEMA
}
}],
tool_choice="required")
with pytest.raises(openai.BadRequestError):
await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=1000,
tools=[{
"type": "function",
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": TEST_SCHEMA
}
}],
tool_choice="auto")


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", ["outlines"])
async def test_inconsistent_tool_choice_and_tools(
server, client: openai.AsyncOpenAI, guided_decoding_backend: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role":
"user",
"content":
f"Give an example JSON for an employee profile that "
f"fits this schema: {TEST_SCHEMA}"
}]
with pytest.raises(openai.BadRequestError):
await client.chat.completions.create(model=MODEL_NAME,
messages=messages,
max_tokens=1000,
tool_choice={
"type": "function",
"function": {
"name":
"dummy_function_name"
}
})
with pytest.raises(openai.BadRequestError):
await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=1000,
tools=[{
"type": "function",
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": TEST_SCHEMA
}
}],
tool_choice={
"type": "function",
"function": {
"name": "nondefined_function_name"
}
})


@pytest.mark.asyncio
async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
for _ in range(2):

View File

@@ -24,7 +24,8 @@ class ServerRunner:
env = os.environ.copy()
env["PYTHONUNBUFFERED"] = "1"
self.proc = subprocess.Popen(
["python3", "-m", "vllm.entrypoints.openai.api_server"] + args,
[sys.executable, "-m", "vllm.entrypoints.openai.api_server"] +
args,
env=env,
stdout=sys.stdout,
stderr=sys.stderr,

View File

@@ -102,6 +102,26 @@ class ResponseFormat(OpenAIBaseModel):
type: Literal["text", "json_object"]


class FunctionDefinition(OpenAIBaseModel):
    name: str
    description: Optional[str] = None
    parameters: Optional[Dict[str, Any]] = None


class ChatCompletionToolsParam(OpenAIBaseModel):
    type: Literal["function"] = "function"
    function: FunctionDefinition


class ChatCompletionNamedFunction(OpenAIBaseModel):
    name: str


class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
    function: ChatCompletionNamedFunction
    type: Literal["function"] = "function"


class ChatCompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/chat/create
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/chat/create
@@ -122,6 +142,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
stream: Optional[bool] = False
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
tools: Optional[List[ChatCompletionToolsParam]] = None
tool_choice: Optional[Union[Literal["none"],
ChatCompletionNamedToolChoiceParam]] = "none"
user: Optional[str] = None
# doc: begin-chat-completion-sampling-params
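
As a rough sketch, a request carrying a named tool parses into these models as follows (field values are illustrative):

```python
from vllm.entrypoints.openai.protocol import ChatCompletionRequest

request = ChatCompletionRequest(
    model="dummy-model",
    messages=[{"role": "user", "content": "hi"}],
    tools=[{
        "type": "function",
        "function": {
            "name": "dummy_function_name",
            "description": "This is a dummy function",
            "parameters": {"type": "object"},  # any JSON schema
        },
    }],
    tool_choice={
        "type": "function",
        "function": {"name": "dummy_function_name"},
    })

# The plain dicts are validated into the typed models above.
assert request.tool_choice.function.name == "dummy_function_name"
assert request.tools[0].function.parameters == {"type": "object"}
```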
@@ -245,10 +268,27 @@ class ChatCompletionRequest(OpenAIBaseModel):
"guided_regex" in data and data["guided_regex"] is not None,
"guided_choice" in data and data["guided_choice"] is not None
])
# you can only use one kind of guided decoding
if guide_count > 1:
raise ValueError(
"You can only use one kind of guided decoding "
"('guided_json', 'guided_regex' or 'guided_choice').")
        # you can only either use guided decoding or tools, not both
        if guide_count > 0 and "tool_choice" in data and data[
                "tool_choice"] != "none":
            raise ValueError(
                "You can only either use guided decoding or tools, not both.")
return data

    @model_validator(mode="before")
@classmethod
def check_tool_choice(cls, data):
if "tool_choice" in data and data["tool_choice"] != "none":
if not isinstance(data["tool_choice"], dict):
raise ValueError("Currently only named tools are supported.")
if "tools" not in data or data["tools"] is None:
raise ValueError(
"When using `tool_choice`, `tools` must be set.")
return data

    @model_validator(mode="before")
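
A sketch of the resulting behavior: because these validators run in `mode="before"`, pydantic surfaces the `ValueError` as a `ValidationError` when the request model is constructed (values illustrative):

```python
import pydantic

from vllm.entrypoints.openai.protocol import ChatCompletionRequest

try:
    ChatCompletionRequest(
        model="dummy-model",
        messages=[{"role": "user", "content": "hi"}],
        # tool_choice without a `tools` list: check_tool_choice rejects it.
        tool_choice={
            "type": "function",
            "function": {"name": "dummy_function_name"},
        })
except pydantic.ValidationError as err:
    print(err)  # When using `tool_choice`, `tools` must be set.
```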
@@ -506,9 +546,21 @@ class EmbeddingResponse(BaseModel):
usage: UsageInfo


class FunctionCall(OpenAIBaseModel):
    name: str
    arguments: str


class ToolCall(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
    type: Literal["function"] = "function"
    function: FunctionCall


class ChatMessage(OpenAIBaseModel):
    role: str
    content: str
    tool_calls: List[ToolCall] = Field(default_factory=list)


class ChatCompletionLogProb(OpenAIBaseModel):
@@ -535,7 +587,7 @@ class ChatCompletionResponseChoice(OpenAIBaseModel):
class ChatCompletionResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: str = "chat.completion"
object: Literal["chat.completion"] = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseChoice]
@@ -545,6 +597,7 @@ class ChatCompletionResponse(OpenAIBaseModel):
class DeltaMessage(OpenAIBaseModel):
    role: Optional[str] = None
    content: Optional[str] = None
    tool_calls: List[ToolCall] = Field(default_factory=list)


class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
@@ -557,7 +610,7 @@ class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
class ChatCompletionStreamResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: str = "chat.completion.chunk"
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseStreamChoice]
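
As a sketch, a tool-call message assembled from these models serializes along these lines (the arguments string is illustrative):

```python
from vllm.entrypoints.openai.protocol import (ChatMessage, FunctionCall,
                                              ToolCall)

message = ChatMessage(
    role="assistant",
    content="",
    tool_calls=[
        ToolCall(function=FunctionCall(
            name="dummy_function_name",
            arguments='{"name": "Alice", "age": 30}'))  # illustrative
    ])

# `id` and `type` are filled in by the defaults above, roughly:
# {"role": "assistant", "content": "", "tool_calls": [{"id":
#  "chatcmpl-tool-<uuid>", "type": "function", "function": {...}}]}
print(message.model_dump_json())
```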

View File

@@ -14,10 +14,11 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (
ChatCompletionContentPartParam, ChatCompletionLogProb,
ChatCompletionLogProbs, ChatCompletionLogProbsContent,
ChatCompletionMessageParam, ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionMessageParam, ChatCompletionNamedToolChoiceParam,
ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
UsageInfo)
FunctionCall, ToolCall, UsageInfo)
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing)
from vllm.logger import init_logger
@@ -298,11 +299,24 @@ class OpenAIServingChat(OpenAIServing):
delta_text = output.text[len(previous_texts[i]):]
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
if request.tool_choice and type(
request.tool_choice
) is ChatCompletionNamedToolChoiceParam:
delta_message = DeltaMessage(tool_calls=[
ToolCall(function=FunctionCall(
name=request.tool_choice.function.name,
arguments=delta_text))
])
else:
delta_message = DeltaMessage(content=delta_text)
if output.finish_reason is None:
# Send token-by-token response for each request.n
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(content=delta_text),
delta=delta_message,
logprobs=logprobs,
finish_reason=None)
chunk = ChatCompletionStreamResponse(
@@ -324,7 +338,7 @@ class OpenAIServingChat(OpenAIServing):
)
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(content=delta_text),
delta=delta_message,
logprobs=logprobs,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason)
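
On the client side, the tool-call arguments arrive spread over the stream; a minimal sketch that reassembles them, assuming `stream` was created with `stream=True` and a named `tool_choice` as in the test above:

```python
async def collect_tool_arguments(stream) -> str:
    """Concatenate the per-chunk tool-call argument deltas."""
    parts = []
    async for chunk in stream:
        delta = chunk.choices[0].delta
        if delta.tool_calls:
            parts.append(delta.tool_calls[0].function.arguments)
    # The joined string is the complete JSON document produced under
    # guided decoding, so it parses against the tool's schema.
    return "".join(parts)
```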
@@ -381,9 +395,22 @@ class OpenAIServingChat(OpenAIServing):
else:
logprobs = None
if request.tool_choice and type(
request.tool_choice) is ChatCompletionNamedToolChoiceParam:
message = ChatMessage(
role=role,
content="",
tool_calls=[
ToolCall(function=FunctionCall(
name=request.tool_choice.function.name,
arguments=output.text))
])
elif not request.tool_choice or request.tool_choice == "none":
message = ChatMessage(role=role, content=output.text)
choice_data = ChatCompletionResponseChoice(
index=output.index,
message=ChatMessage(role=role, content=output.text),
message=message,
logprobs=logprobs,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason)

View File

@@ -1,7 +1,8 @@
from typing import Optional, Union
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest)
from vllm.entrypoints.openai.protocol import (
ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
CompletionRequest)
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (
get_lm_format_enforcer_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.outlines_decoding import (
@@ -13,6 +14,8 @@ async def get_guided_decoding_logits_processor(
guided_decoding_backend: str, request: Union[CompletionRequest,
ChatCompletionRequest],
tokenizer) -> Optional[LogitsProcessor]:
request = _adapt_request_for_tool_use(request)
if guided_decoding_backend == 'outlines':
return await get_outlines_guided_decoding_logits_processor(
request, tokenizer)
@@ -23,3 +26,26 @@ async def get_guided_decoding_logits_processor(
    raise ValueError(
        f"Unknown guided decoding backend '{guided_decoding_backend}'. "
        "Must be one of 'outlines', 'lm-format-enforcer'")


def _adapt_request_for_tool_use(request: Union[CompletionRequest,
                                               ChatCompletionRequest]):
# the legacy completion API does not support tool use
if type(request) is CompletionRequest:
return request
    # the user has chosen not to use any tool
if request.tool_choice == "none":
return request
    # the user has chosen a specific named tool
if type(request.tool_choice) is ChatCompletionNamedToolChoiceParam:
tool_name = request.tool_choice.function.name
tools = {tool.function.name: tool.function for tool in request.tools}
if tool_name not in tools:
raise ValueError(
f"Tool '{tool_name}' has not been passed in `tools`.")
tool = tools[tool_name]
request.guided_json = tool.parameters
return request
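
The net effect: a named tool call is rewritten into a `guided_json` request before a backend is selected. A sketch, assuming the helper is imported from `vllm.model_executor.guided_decoding` as in this diff:

```python
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.model_executor.guided_decoding import _adapt_request_for_tool_use

request = ChatCompletionRequest(
    model="dummy-model",  # illustrative
    messages=[{"role": "user", "content": "hi"}],
    tools=[{
        "type": "function",
        "function": {
            "name": "dummy_function_name",
            "parameters": {"type": "object"},  # illustrative schema
        },
    }],
    tool_choice={
        "type": "function",
        "function": {"name": "dummy_function_name"},
    })

adapted = _adapt_request_for_tool_use(request)
# The named tool's JSON schema now drives guided decoding.
assert adapted.guided_json == {"type": "object"}
```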