[Frontend] Respect Chat Completion parallel_tool_calls param (#26233)
Signed-off-by: Ben Browning <bbrownin@redhat.com>
Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
parent a685b47c57
commit e1dd706cd1
@@ -49,7 +49,8 @@ We currently support the following OpenAI APIs:
     - *Note: `suffix` parameter is not supported.*
 - [Chat Completions API](#chat-api) (`/v1/chat/completions`)
     - Only applicable to [text generation models](../models/generative_models.md) with a [chat template](../serving/openai_compatible_server.md#chat-template).
-    - *Note: `parallel_tool_calls` and `user` parameters are ignored.*
+    - *Note: `user` parameter is ignored.*
+    - *Note:* Setting the `parallel_tool_calls` parameter to `false` ensures vLLM only returns zero or one tool call per request. Setting it to `true` (the default) allows returning more than one tool call per request. There is no guarantee more than one tool call will be returned if this is set to `true`, as that behavior is model dependent and not all models are designed to support parallel tool calls.
 - [Embeddings API](#embeddings-api) (`/v1/embeddings`)
     - Only applicable to [embedding models](../models/pooling_models.md).
 - [Transcriptions API](#transcriptions-api) (`/v1/audio/transcriptions`)
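Example (illustrative, not part of the patch): a minimal client-side sketch of the documented behavior, sending a Chat Completions request with parallel_tool_calls=False through the standard openai client. The server URL, API key, model name, and get_weather tool definition are placeholders, and it assumes a vLLM server started with a tool-call parser and auto tool choice enabled.

# Hedged sketch: only the parallel_tool_calls handling comes from this PR;
# the endpoint, model name, and tool schema below are placeholders.
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

weather_tool = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Look up the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}

completion = client.chat.completions.create(
    model="my-model",  # placeholder model name
    messages=[{"role": "user", "content": "What is the weather in Boston and in Dallas?"}],
    tools=[weather_tool],
    parallel_tool_calls=False,  # vLLM now returns at most one tool call
)

tool_calls = completion.choices[0].message.tool_calls or []
assert len(tool_calls) <= 1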
@@ -212,3 +212,60 @@ async def test_parallel_tool_calls_with_results(
     assert finish_reason_count == 1
     assert len(chunks)
     assert "".join(chunks) == choice.message.content
+
+
+@pytest.mark.asyncio
+async def test_parallel_tool_calls_false(client: openai.AsyncOpenAI):
+    """
+    Ensure only one tool call is returned when parallel_tool_calls is False.
+    """
+
+    models = await client.models.list()
+    model_name: str = models.data[0].id
+    chat_completion = await client.chat.completions.create(
+        messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
+        temperature=0,
+        max_completion_tokens=200,
+        model=model_name,
+        tools=[WEATHER_TOOL, SEARCH_TOOL],
+        logprobs=False,
+        parallel_tool_calls=False,
+    )
+
+    stop_reason = chat_completion.choices[0].finish_reason
+    non_streamed_tool_calls = chat_completion.choices[0].message.tool_calls
+
+    # make sure only 1 tool call is present
+    assert len(non_streamed_tool_calls) == 1
+    assert stop_reason == "tool_calls"
+
+    # make the same request, streaming
+    stream = await client.chat.completions.create(
+        model=model_name,
+        messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
+        temperature=0,
+        max_completion_tokens=200,
+        tools=[WEATHER_TOOL, SEARCH_TOOL],
+        logprobs=False,
+        parallel_tool_calls=False,
+        stream=True,
+    )
+
+    finish_reason_count: int = 0
+    tool_call_id_count: int = 0
+
+    async for chunk in stream:
+        # if there's a finish reason make sure it's tools
+        if chunk.choices[0].finish_reason:
+            finish_reason_count += 1
+            assert chunk.choices[0].finish_reason == "tool_calls"
+
+        streamed_tool_calls = chunk.choices[0].delta.tool_calls
+        if streamed_tool_calls and len(streamed_tool_calls) > 0:
+            tool_call = streamed_tool_calls[0]
+            if tool_call.id:
+                tool_call_id_count += 1
+
+    # make sure only 1 streaming tool call is present
+    assert tool_call_id_count == 1
+    assert finish_reason_count == 1
@@ -559,9 +559,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
     ) = "none"
     reasoning_effort: Literal["low", "medium", "high"] | None = None
     include_reasoning: bool = True
+    parallel_tool_calls: bool | None = True
 
-    # NOTE this will be ignored by vLLM -- the model determines the behavior
-    parallel_tool_calls: bool | None = False
+    # NOTE this will be ignored by vLLM
     user: str | None = None
 
     # --8<-- [start:chat-completion-sampling-params]
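Example (illustrative, not part of the patch): the request model now defaults parallel_tool_calls to True, matching the OpenAI API, and preserves an explicit false so the serving layer can filter down to a single tool call. A quick sketch of that default, assuming model and messages are the only fields that must be supplied to ChatCompletionRequest:

# Hedged sketch of the new default on ChatCompletionRequest; assumes the model
# validates with just `model` and `messages` supplied, everything else defaulted.
from vllm.entrypoints.openai.protocol import ChatCompletionRequest

req = ChatCompletionRequest(
    model="my-model",  # placeholder
    messages=[{"role": "user", "content": "hi"}],
)
assert req.parallel_tool_calls is True  # new default mirrors the OpenAI API

req_single = ChatCompletionRequest(
    model="my-model",
    messages=[{"role": "user", "content": "hi"}],
    parallel_tool_calls=False,
)
assert req_single.parallel_tool_calls is False  # no longer silently ignored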
@@ -55,6 +55,7 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_l
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.openai.tool_parsers import ToolParser
 from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolCall
+from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
 from vllm.entrypoints.utils import get_max_tokens, should_include_usage
 from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
 from vllm.logger import init_logger
@@ -1206,6 +1207,7 @@ class OpenAIServingChat(OpenAIServing):
 
                 finish_reason_sent[i] = True
 
+            choice_data = maybe_filter_parallel_tool_calls(choice_data, request)
             chunk = ChatCompletionStreamResponse(
                 id=request_id,
                 object=chunk_object_type,
@@ -1531,6 +1533,7 @@ class OpenAIServingChat(OpenAIServing):
                     as_list(output.token_ids) if request.return_token_ids else None
                 ),
             )
+            choice_data = maybe_filter_parallel_tool_calls(choice_data, request)
 
             choices.append(choice_data)
 
@@ -296,11 +296,7 @@ class OpenAIServing:
         parser = None
         if not enable_auto_tools or tool_parser_name is None:
             return parser
-        logger.info(
-            '"auto" tool choice has been enabled please note that while'
-            " the parallel_tool_calls client option is preset for "
-            "compatibility reasons, it will be ignored."
-        )
+        logger.info('"auto" tool choice has been enabled.')
 
         try:
             if tool_parser_name == "pythonic" and self.model_config.model.startswith(
vllm/entrypoints/openai/utils.py (new file, 37 lines)
@@ -0,0 +1,37 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import TypeVar
+
+from vllm.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    ChatCompletionResponseChoice,
+    ChatCompletionResponseStreamChoice,
+)
+
+# Used internally
+_ChatCompletionResponseChoiceT = TypeVar(
+    "_ChatCompletionResponseChoiceT",
+    ChatCompletionResponseChoice,
+    ChatCompletionResponseStreamChoice,
+)
+
+
+def maybe_filter_parallel_tool_calls(
+    choice: _ChatCompletionResponseChoiceT, request: ChatCompletionRequest
+) -> _ChatCompletionResponseChoiceT:
+    """Filter to first tool call only when parallel_tool_calls is False."""
+
+    if request.parallel_tool_calls:
+        return choice
+
+    if isinstance(choice, ChatCompletionResponseChoice) and choice.message.tool_calls:
+        choice.message.tool_calls = choice.message.tool_calls[:1]
+    elif (
+        isinstance(choice, ChatCompletionResponseStreamChoice)
+        and choice.delta.tool_calls
+    ):
+        choice.delta.tool_calls = [
+            tool_call for tool_call in choice.delta.tool_calls if tool_call.index == 0
+        ]
+
+    return choice
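Example (illustrative, not part of the patch): a rough sketch of the streaming branch of maybe_filter_parallel_tool_calls, which keeps only delta fragments with index == 0 because the first tool call's arguments may arrive spread across many chunks. The constructor keyword names for the protocol models below are assumptions; only the attributes the helper actually reads (delta.tool_calls, tool_call.index, parallel_tool_calls) come from the patch itself.

# Hedged sketch of the streaming filter; exact constructor fields of the
# protocol models are assumed, the filtering logic mirrors the new helper.
from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    ChatCompletionResponseStreamChoice,
    DeltaMessage,
    DeltaToolCall,
)
from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls

request = ChatCompletionRequest(
    model="my-model",  # placeholder
    messages=[{"role": "user", "content": "hi"}],
    parallel_tool_calls=False,
)
choice = ChatCompletionResponseStreamChoice(
    index=0,
    delta=DeltaMessage(tool_calls=[DeltaToolCall(index=0), DeltaToolCall(index=1)]),
)
choice = maybe_filter_parallel_tool_calls(choice, request)
# Only fragments belonging to the first tool call (index == 0) remain.
assert all(tc.index == 0 for tc in choice.delta.tool_calls)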