Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-13 22:35:32 +08:00)
[FRONTEND] OpenAI tools support named functions (#5032)
commit f775a07e30
parent 4f0d17c05c
@@ -110,3 +110,14 @@ directory [here](https://github.com/vllm-project/vllm/tree/main/examples/)
 :func: make_arg_parser
 :prog: -m vllm.entrypoints.openai.api_server
 ```
+
+## Tool calling in the chat completion API
+vLLM supports only named function calling in the chat completion API. The `tool_choice` options `auto` and `required` are **not yet supported** but are on the roadmap.
+
+To use a named function, define the function in the `tools` parameter and select it in the `tool_choice` parameter.
+
+It is the caller's responsibility to prompt the model with the tool information; vLLM will not automatically manipulate the prompt. **This may change in the future.**
+
+vLLM will use guided decoding to ensure the response matches the tool parameter object defined by the JSON schema in the `tools` parameter.
+
+Please refer to the OpenAI API reference documentation for more information.
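
For illustration only (not part of this diff), a minimal sketch of a named-function call against a running vLLM OpenAI-compatible server. The base URL, API key, model name, schema, and `get_weather` function are placeholders chosen for the example:

```python
# Hypothetical usage sketch; assumes a vLLM server on localhost:8000 and the
# `openai` (>= 1.x) Python client. All names and the schema are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

weather_schema = {
    "type": "object",
    "properties": {
        "city": {"type": "string"},
        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
    },
    "required": ["city"],
}

completion = client.chat.completions.create(
    model="my-model",  # placeholder model name
    messages=[
        # vLLM does not inject tool descriptions into the prompt, so the
        # caller describes the tool in the messages themselves.
        {"role": "system", "content": "Call get_weather with JSON arguments."},
        {"role": "user", "content": "What's the weather in Berlin?"},
    ],
    tools=[{
        "type": "function",
        "function": {"name": "get_weather", "parameters": weather_schema},
    }],
    # Named function: guided decoding constrains the output to weather_schema.
    tool_choice={"type": "function", "function": {"name": "get_weather"}},
)

# The arguments arrive as a JSON string on the returned tool call.
print(completion.choices[0].message.tool_calls[0].function.arguments)
```
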
@@ -906,6 +906,191 @@ async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI,
                     for token in top_logprobs)
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize("guided_decoding_backend",
+                         ["outlines", "lm-format-enforcer"])
+async def test_named_tool_use(server, client: openai.AsyncOpenAI,
+                              guided_decoding_backend: str):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role":
+        "user",
+        "content":
+        f"Give an example JSON for an employee profile that "
+        f"fits this schema: {TEST_SCHEMA}"
+    }]
+
+    # non-streaming
+
+    chat_completion = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=messages,
+        max_tokens=1000,
+        tools=[{
+            "type": "function",
+            "function": {
+                "name": "dummy_function_name",
+                "description": "This is a dummy function",
+                "parameters": TEST_SCHEMA
+            }
+        }],
+        tool_choice={
+            "type": "function",
+            "function": {
+                "name": "dummy_function_name"
+            }
+        })
+    message = chat_completion.choices[0].message
+    assert len(message.content) == 0
+    json_string = message.tool_calls[0].function.arguments
+    json1 = json.loads(json_string)
+    jsonschema.validate(instance=json1, schema=TEST_SCHEMA)
+
+    messages.append({"role": "assistant", "content": json_string})
+    messages.append({
+        "role":
+        "user",
+        "content":
+        "Give me another one with a different name and age"
+    })
+
+    # streaming
+
+    stream = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=messages,
+        max_tokens=1000,
+        tools=[{
+            "type": "function",
+            "function": {
+                "name": "dummy_function_name",
+                "description": "This is a dummy function",
+                "parameters": TEST_SCHEMA
+            }
+        }],
+        tool_choice={
+            "type": "function",
+            "function": {
+                "name": "dummy_function_name"
+            }
+        },
+        stream=True)
+
+    output = []
+    finish_reason_count = 0
+    async for chunk in stream:
+        delta = chunk.choices[0].delta
+        if delta.role:
+            assert delta.role == "assistant"
+        assert delta.content is None or len(delta.content) == 0
+        if delta.tool_calls:
+            output.append(delta.tool_calls[0].function.arguments)
+        if chunk.choices[0].finish_reason is not None:
+            finish_reason_count += 1
+    # finish reason should only return in last block
+    assert finish_reason_count == 1
+    json2 = json.loads("".join(output))
+    jsonschema.validate(instance=json2, schema=TEST_SCHEMA)
+    assert json1["name"] != json2["name"]
+    assert json1["age"] != json2["age"]
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("guided_decoding_backend", ["outlines"])
+async def test_required_tool_use_not_yet_supported(
+        server, client: openai.AsyncOpenAI, guided_decoding_backend: str):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role":
+        "user",
+        "content":
+        f"Give an example JSON for an employee profile that "
+        f"fits this schema: {TEST_SCHEMA}"
+    }]
+
+    with pytest.raises(openai.BadRequestError):
+        await client.chat.completions.create(
+            model=MODEL_NAME,
+            messages=messages,
+            max_tokens=1000,
+            tools=[{
+                "type": "function",
+                "function": {
+                    "name": "dummy_function_name",
+                    "description": "This is a dummy function",
+                    "parameters": TEST_SCHEMA
+                }
+            }],
+            tool_choice="required")
+
+    with pytest.raises(openai.BadRequestError):
+        await client.chat.completions.create(
+            model=MODEL_NAME,
+            messages=messages,
+            max_tokens=1000,
+            tools=[{
+                "type": "function",
+                "function": {
+                    "name": "dummy_function_name",
+                    "description": "This is a dummy function",
+                    "parameters": TEST_SCHEMA
+                }
+            }],
+            tool_choice="auto")
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("guided_decoding_backend", ["outlines"])
+async def test_inconsistent_tool_choice_and_tools(
+        server, client: openai.AsyncOpenAI, guided_decoding_backend: str):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role":
+        "user",
+        "content":
+        f"Give an example JSON for an employee profile that "
+        f"fits this schema: {TEST_SCHEMA}"
+    }]
+
+    with pytest.raises(openai.BadRequestError):
+        await client.chat.completions.create(model=MODEL_NAME,
+                                             messages=messages,
+                                             max_tokens=1000,
+                                             tool_choice={
+                                                 "type": "function",
+                                                 "function": {
+                                                     "name":
+                                                     "dummy_function_name"
+                                                 }
+                                             })
+
+    with pytest.raises(openai.BadRequestError):
+        await client.chat.completions.create(
+            model=MODEL_NAME,
+            messages=messages,
+            max_tokens=1000,
+            tools=[{
+                "type": "function",
+                "function": {
+                    "name": "dummy_function_name",
+                    "description": "This is a dummy function",
+                    "parameters": TEST_SCHEMA
+                }
+            }],
+            tool_choice={
+                "type": "function",
+                "function": {
+                    "name": "nondefined_function_name"
+                }
+            })
+
+
 @pytest.mark.asyncio
 async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
     for _ in range(2):
@@ -24,7 +24,8 @@ class ServerRunner:
         env = os.environ.copy()
         env["PYTHONUNBUFFERED"] = "1"
         self.proc = subprocess.Popen(
-            ["python3", "-m", "vllm.entrypoints.openai.api_server"] + args,
+            [sys.executable, "-m", "vllm.entrypoints.openai.api_server"] +
+            args,
             env=env,
             stdout=sys.stdout,
             stderr=sys.stderr,
@@ -102,6 +102,26 @@ class ResponseFormat(OpenAIBaseModel):
     type: Literal["text", "json_object"]
 
 
+class FunctionDefinition(OpenAIBaseModel):
+    name: str
+    description: Optional[str] = None
+    parameters: Optional[Dict[str, Any]] = None
+
+
+class ChatCompletionToolsParam(OpenAIBaseModel):
+    type: Literal["function"] = "function"
+    function: FunctionDefinition
+
+
+class ChatCompletionNamedFunction(OpenAIBaseModel):
+    name: str
+
+
+class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
+    function: ChatCompletionNamedFunction
+    type: Literal["function"] = "function"
+
+
 class ChatCompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/chat/create
@@ -122,6 +142,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
     stream: Optional[bool] = False
     temperature: Optional[float] = 0.7
     top_p: Optional[float] = 1.0
+    tools: Optional[List[ChatCompletionToolsParam]] = None
+    tool_choice: Optional[Union[Literal["none"],
+                                ChatCompletionNamedToolChoiceParam]] = "none"
     user: Optional[str] = None
 
     # doc: begin-chat-completion-sampling-params
@@ -245,10 +268,27 @@ class ChatCompletionRequest(OpenAIBaseModel):
             "guided_regex" in data and data["guided_regex"] is not None,
             "guided_choice" in data and data["guided_choice"] is not None
         ])
+        # you can only use one kind of guided decoding
         if guide_count > 1:
             raise ValueError(
                 "You can only use one kind of guided decoding "
                 "('guided_json', 'guided_regex' or 'guided_choice').")
+        # you can only either use guided decoding or tools, not both
+        if guide_count > 1 and "tool_choice" in data and data[
+                "tool_choice"] != "none":
+            raise ValueError(
+                "You can only either use guided decoding or tools, not both.")
+        return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_tool_choice(cls, data):
+        if "tool_choice" in data and data["tool_choice"] != "none":
+            if not isinstance(data["tool_choice"], dict):
+                raise ValueError("Currently only named tools are supported.")
+            if "tools" not in data or data["tools"] is None:
+                raise ValueError(
+                    "When using `tool_choice`, `tools` must be set.")
         return data
 
     @model_validator(mode="before")
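
As an illustration of what the new `check_tool_choice` validator enforces (this sketch is not part of the commit; the payloads are hypothetical and trimmed to the relevant fields):

```python
# Hypothetical request fragments; only the fields the validator inspects
# are shown. Field values are placeholders.

# Rejected: `tool_choice` is a bare string other than "none" -- only named
# tools (a dict) are currently supported.
bad_choice = {
    "tool_choice": "auto",
    "tools": [{"type": "function", "function": {"name": "f"}}],
}

# Rejected: a named `tool_choice` without any `tools` defined.
missing_tools = {
    "tool_choice": {"type": "function", "function": {"name": "f"}},
}

# Accepted: a named `tool_choice` referencing a defined tool.
ok = {
    "tools": [{"type": "function", "function": {"name": "f"}}],
    "tool_choice": {"type": "function", "function": {"name": "f"}},
}
```
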
@@ -506,9 +546,21 @@ class EmbeddingResponse(BaseModel):
     usage: UsageInfo
 
 
+class FunctionCall(OpenAIBaseModel):
+    name: str
+    arguments: str
+
+
+class ToolCall(OpenAIBaseModel):
+    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
+    type: Literal["function"] = "function"
+    function: FunctionCall
+
+
 class ChatMessage(OpenAIBaseModel):
     role: str
     content: str
+    tool_calls: List[ToolCall] = Field(default_factory=list)
 
 
 class ChatCompletionLogProb(OpenAIBaseModel):
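
For orientation (not part of the diff), roughly how an assistant message carrying a tool call serializes under these models; the values are placeholders and the `id` comes from the `default_factory` above:

```python
# Approximate serialized shape of a ChatMessage with one ToolCall
# (placeholder values only).
example_message = {
    "role": "assistant",
    "content": "",
    "tool_calls": [{
        "id": "chatcmpl-tool-<random_uuid>",
        "type": "function",
        "function": {
            "name": "dummy_function_name",
            # arguments are a JSON-encoded string, not a nested object
            "arguments": '{"name": "Alice", "age": 30}',
        },
    }],
}
```
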
@@ -535,7 +587,7 @@ class ChatCompletionResponseChoice(OpenAIBaseModel):

 class ChatCompletionResponse(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
-    object: str = "chat.completion"
+    object: Literal["chat.completion"] = "chat.completion"
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseChoice]
@@ -545,6 +597,7 @@ class ChatCompletionResponse(OpenAIBaseModel):
 class DeltaMessage(OpenAIBaseModel):
     role: Optional[str] = None
     content: Optional[str] = None
+    tool_calls: List[ToolCall] = Field(default_factory=list)
 
 
 class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
@@ -557,7 +610,7 @@ class ChatCompletionResponseStreamChoice(OpenAIBaseModel):

 class ChatCompletionStreamResponse(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
-    object: str = "chat.completion.chunk"
+    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseStreamChoice]
@@ -14,10 +14,11 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import (
     ChatCompletionContentPartParam, ChatCompletionLogProb,
     ChatCompletionLogProbs, ChatCompletionLogProbsContent,
-    ChatCompletionMessageParam, ChatCompletionRequest, ChatCompletionResponse,
+    ChatCompletionMessageParam, ChatCompletionNamedToolChoiceParam,
+    ChatCompletionRequest, ChatCompletionResponse,
     ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
     ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
-    UsageInfo)
+    FunctionCall, ToolCall, UsageInfo)
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                     OpenAIServing)
 from vllm.logger import init_logger
@@ -298,11 +299,24 @@ class OpenAIServingChat(OpenAIServing):
                 delta_text = output.text[len(previous_texts[i]):]
                 previous_texts[i] = output.text
                 previous_num_tokens[i] = len(output.token_ids)
+
+                if request.tool_choice and type(
+                        request.tool_choice
+                ) is ChatCompletionNamedToolChoiceParam:
+                    delta_message = DeltaMessage(tool_calls=[
+                        ToolCall(function=FunctionCall(
+                            name=request.tool_choice.function.name,
+                            arguments=delta_text))
+                    ])
+                else:
+                    delta_message = DeltaMessage(content=delta_text)
+
                 if output.finish_reason is None:
                     # Send token-by-token response for each request.n
+
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=i,
-                        delta=DeltaMessage(content=delta_text),
+                        delta=delta_message,
                         logprobs=logprobs,
                         finish_reason=None)
                     chunk = ChatCompletionStreamResponse(
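
To make the streaming path concrete (illustrative only, not from the commit): with a named tool choice, each chunk's delta carries a fragment of the JSON arguments, which the client concatenates, mirroring the pattern in the new streaming test:

```python
# Illustrative client-side handling of streamed tool-call fragments.
# `stream` is assumed to come from client.chat.completions.create(...,
# stream=True) with a named tool_choice.
async def collect_tool_arguments(stream) -> str:
    fragments = []
    async for chunk in stream:
        delta = chunk.choices[0].delta
        if delta.tool_calls:
            # each chunk carries a slice of the JSON arguments string
            fragments.append(delta.tool_calls[0].function.arguments)
    return "".join(fragments)  # the complete JSON arguments for the tool
```
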
@@ -324,7 +338,7 @@ class OpenAIServingChat(OpenAIServing):
                     )
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=i,
-                        delta=DeltaMessage(content=delta_text),
+                        delta=delta_message,
                         logprobs=logprobs,
                         finish_reason=output.finish_reason,
                         stop_reason=output.stop_reason)
@@ -381,9 +395,22 @@ class OpenAIServingChat(OpenAIServing):
             else:
                 logprobs = None
 
+            if request.tool_choice and type(
+                    request.tool_choice) is ChatCompletionNamedToolChoiceParam:
+                message = ChatMessage(
+                    role=role,
+                    content="",
+                    tool_calls=[
+                        ToolCall(function=FunctionCall(
+                            name=request.tool_choice.function.name,
+                            arguments=output.text))
+                    ])
+            elif not request.tool_choice or request.tool_choice == "none":
+                message = ChatMessage(role=role, content=output.text)
+
             choice_data = ChatCompletionResponseChoice(
                 index=output.index,
-                message=ChatMessage(role=role, content=output.text),
+                message=message,
                 logprobs=logprobs,
                 finish_reason=output.finish_reason,
                 stop_reason=output.stop_reason)
@@ -1,6 +1,7 @@
 from typing import Optional, Union
 
-from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
-                                              CompletionRequest)
+from vllm.entrypoints.openai.protocol import (
+    ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
+    CompletionRequest)
 from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (
     get_lm_format_enforcer_guided_decoding_logits_processor)
@@ -13,6 +14,8 @@ async def get_guided_decoding_logits_processor(
         guided_decoding_backend: str, request: Union[CompletionRequest,
                                                       ChatCompletionRequest],
         tokenizer) -> Optional[LogitsProcessor]:
+    request = _adapt_request_for_tool_use(request)
+
     if guided_decoding_backend == 'outlines':
         return await get_outlines_guided_decoding_logits_processor(
             request, tokenizer)
@@ -23,3 +26,26 @@ async def get_guided_decoding_logits_processor(
     raise ValueError(
         f"Unknown guided decoding backend '{guided_decoding_backend}'. "
         "Must be one of 'outlines, 'lm-format-enforcer'")
+
+
+def _adapt_request_for_tool_use(request: Union[CompletionRequest,
+                                               ChatCompletionRequest]):
+    # the legacy completion API does not support tool use
+    if type(request) is CompletionRequest:
+        return request
+
+    # user has chosen to not use any tool
+    if request.tool_choice == "none":
+        return request
+
+    # user has chosen to use a named tool
+    if type(request.tool_choice) is ChatCompletionNamedToolChoiceParam:
+        tool_name = request.tool_choice.function.name
+        tools = {tool.function.name: tool.function for tool in request.tools}
+        if tool_name not in tools:
+            raise ValueError(
+                f"Tool '{tool_name}' has not been passed in `tools`.")
+        tool = tools[tool_name]
+        request.guided_json = tool.parameters
+
+    return request
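
A usage sketch of the adapter (not in the commit), assuming `ChatCompletionRequest` accepts plain dicts for messages and tools as defined in the protocol hunks above; the model name, messages, and schema are placeholders:

```python
# Hypothetical sketch of the adapter's effect on a named-tool request.
from vllm.entrypoints.openai.protocol import ChatCompletionRequest

schema = {"type": "object", "properties": {"city": {"type": "string"}}}
request = ChatCompletionRequest(
    model="my-model",
    messages=[{"role": "user", "content": "Weather in Berlin as JSON?"}],
    tools=[{
        "type": "function",
        "function": {"name": "get_weather", "parameters": schema},
    }],
    tool_choice={"type": "function", "function": {"name": "get_weather"}},
)

adapted = _adapt_request_for_tool_use(request)
# The named tool's parameter schema is applied as the guided_json constraint,
# so decoding is restricted to outputs matching `schema`.
assert adapted.guided_json == schema
```
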