mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 03:44:58 +08:00
[Frontend] Respect Chat Completion parallel_tool_calls param (#26233)
Signed-off-by: Ben Browning <bbrownin@redhat.com> Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
This commit is contained in:
parent
a685b47c57
commit
e1dd706cd1
@ -49,7 +49,8 @@ We currently support the following OpenAI APIs:
|
||||
- *Note: `suffix` parameter is not supported.*
|
||||
- [Chat Completions API](#chat-api) (`/v1/chat/completions`)
|
||||
- Only applicable to [text generation models](../models/generative_models.md) with a [chat template](../serving/openai_compatible_server.md#chat-template).
|
||||
- *Note: `parallel_tool_calls` and `user` parameters are ignored.*
|
||||
- *Note: `user` parameter is ignored.*
|
||||
- *Note:* Setting the `parallel_tool_calls` parameter to `false` ensures vLLM only returns zero or one tool call per request. Setting it to `true` (the default) allows returning more than one tool call per request. There is no guarantee more than one tool call will be returned if this is set to `true`, as that behavior is model dependent and not all models are designed to support parallel tool calls.
|
||||
- [Embeddings API](#embeddings-api) (`/v1/embeddings`)
|
||||
- Only applicable to [embedding models](../models/pooling_models.md).
|
||||
- [Transcriptions API](#transcriptions-api) (`/v1/audio/transcriptions`)
|
||||
|
||||
@ -212,3 +212,60 @@ async def test_parallel_tool_calls_with_results(
|
||||
assert finish_reason_count == 1
|
||||
assert len(chunks)
|
||||
assert "".join(chunks) == choice.message.content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_tool_calls_false(client: openai.AsyncOpenAI):
|
||||
"""
|
||||
Ensure only one tool call is returned when parallel_tool_calls is False.
|
||||
"""
|
||||
|
||||
models = await client.models.list()
|
||||
model_name: str = models.data[0].id
|
||||
chat_completion = await client.chat.completions.create(
|
||||
messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
|
||||
temperature=0,
|
||||
max_completion_tokens=200,
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False,
|
||||
parallel_tool_calls=False,
|
||||
)
|
||||
|
||||
stop_reason = chat_completion.choices[0].finish_reason
|
||||
non_streamed_tool_calls = chat_completion.choices[0].message.tool_calls
|
||||
|
||||
# make sure only 1 tool call is present
|
||||
assert len(non_streamed_tool_calls) == 1
|
||||
assert stop_reason == "tool_calls"
|
||||
|
||||
# make the same request, streaming
|
||||
stream = await client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
|
||||
temperature=0,
|
||||
max_completion_tokens=200,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False,
|
||||
parallel_tool_calls=False,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
finish_reason_count: int = 0
|
||||
tool_call_id_count: int = 0
|
||||
|
||||
async for chunk in stream:
|
||||
# if there's a finish reason make sure it's tools
|
||||
if chunk.choices[0].finish_reason:
|
||||
finish_reason_count += 1
|
||||
assert chunk.choices[0].finish_reason == "tool_calls"
|
||||
|
||||
streamed_tool_calls = chunk.choices[0].delta.tool_calls
|
||||
if streamed_tool_calls and len(streamed_tool_calls) > 0:
|
||||
tool_call = streamed_tool_calls[0]
|
||||
if tool_call.id:
|
||||
tool_call_id_count += 1
|
||||
|
||||
# make sure only 1 streaming tool call is present
|
||||
assert tool_call_id_count == 1
|
||||
assert finish_reason_count == 1
|
||||
|
||||
@ -559,9 +559,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
) = "none"
|
||||
reasoning_effort: Literal["low", "medium", "high"] | None = None
|
||||
include_reasoning: bool = True
|
||||
parallel_tool_calls: bool | None = True
|
||||
|
||||
# NOTE this will be ignored by vLLM -- the model determines the behavior
|
||||
parallel_tool_calls: bool | None = False
|
||||
# NOTE this will be ignored by vLLM
|
||||
user: str | None = None
|
||||
|
||||
# --8<-- [start:chat-completion-sampling-params]
|
||||
|
||||
@ -55,6 +55,7 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_l
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser
|
||||
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolCall
|
||||
from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
|
||||
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
|
||||
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
@ -1206,6 +1207,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
|
||||
finish_reason_sent[i] = True
|
||||
|
||||
choice_data = maybe_filter_parallel_tool_calls(choice_data, request)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=request_id,
|
||||
object=chunk_object_type,
|
||||
@ -1531,6 +1533,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
as_list(output.token_ids) if request.return_token_ids else None
|
||||
),
|
||||
)
|
||||
choice_data = maybe_filter_parallel_tool_calls(choice_data, request)
|
||||
|
||||
choices.append(choice_data)
|
||||
|
||||
|
||||
@ -296,11 +296,7 @@ class OpenAIServing:
|
||||
parser = None
|
||||
if not enable_auto_tools or tool_parser_name is None:
|
||||
return parser
|
||||
logger.info(
|
||||
'"auto" tool choice has been enabled please note that while'
|
||||
" the parallel_tool_calls client option is preset for "
|
||||
"compatibility reasons, it will be ignored."
|
||||
)
|
||||
logger.info('"auto" tool choice has been enabled.')
|
||||
|
||||
try:
|
||||
if tool_parser_name == "pythonic" and self.model_config.model.startswith(
|
||||
|
||||
37
vllm/entrypoints/openai/utils.py
Normal file
37
vllm/entrypoints/openai/utils.py
Normal file
@ -0,0 +1,37 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import TypeVar
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponseChoice,
|
||||
ChatCompletionResponseStreamChoice,
|
||||
)
|
||||
|
||||
# Used internally
|
||||
_ChatCompletionResponseChoiceT = TypeVar(
|
||||
"_ChatCompletionResponseChoiceT",
|
||||
ChatCompletionResponseChoice,
|
||||
ChatCompletionResponseStreamChoice,
|
||||
)
|
||||
|
||||
|
||||
def maybe_filter_parallel_tool_calls(
|
||||
choice: _ChatCompletionResponseChoiceT, request: ChatCompletionRequest
|
||||
) -> _ChatCompletionResponseChoiceT:
|
||||
"""Filter to first tool call only when parallel_tool_calls is False."""
|
||||
|
||||
if request.parallel_tool_calls:
|
||||
return choice
|
||||
|
||||
if isinstance(choice, ChatCompletionResponseChoice) and choice.message.tool_calls:
|
||||
choice.message.tool_calls = choice.message.tool_calls[:1]
|
||||
elif (
|
||||
isinstance(choice, ChatCompletionResponseStreamChoice)
|
||||
and choice.delta.tool_calls
|
||||
):
|
||||
choice.delta.tool_calls = [
|
||||
tool_call for tool_call in choice.delta.tool_calls if tool_call.index == 0
|
||||
]
|
||||
|
||||
return choice
|
||||
Loading…
x
Reference in New Issue
Block a user