diff --git a/docs/serving/openai_compatible_server.md b/docs/serving/openai_compatible_server.md
index 23df3963823a..e3280bd15b55 100644
--- a/docs/serving/openai_compatible_server.md
+++ b/docs/serving/openai_compatible_server.md
@@ -49,7 +49,8 @@ We currently support the following OpenAI APIs:
     - *Note: `suffix` parameter is not supported.*
 - [Chat Completions API](#chat-api) (`/v1/chat/completions`)
     - Only applicable to [text generation models](../models/generative_models.md) with a [chat template](../serving/openai_compatible_server.md#chat-template).
-    - *Note: `parallel_tool_calls` and `user` parameters are ignored.*
+    - *Note: `user` parameter is ignored.*
+    - *Note:* Setting `parallel_tool_calls` to `false` ensures vLLM returns at most one tool call per request. Setting it to `true` (the default) allows more than one tool call per request, but does not guarantee it: the behavior is model dependent, and not all models are designed to support parallel tool calls.
 - [Embeddings API](#embeddings-api) (`/v1/embeddings`)
     - Only applicable to [embedding models](../models/pooling_models.md).
 - [Transcriptions API](#transcriptions-api) (`/v1/audio/transcriptions`)
diff --git a/tests/tool_use/test_parallel_tool_calls.py b/tests/tool_use/test_parallel_tool_calls.py
index 9af94a6a64a2..77084ec2d945 100644
--- a/tests/tool_use/test_parallel_tool_calls.py
+++ b/tests/tool_use/test_parallel_tool_calls.py
@@ -212,3 +212,60 @@ async def test_parallel_tool_calls_with_results(
     assert finish_reason_count == 1
     assert len(chunks)
     assert "".join(chunks) == choice.message.content
+
+
+@pytest.mark.asyncio
+async def test_parallel_tool_calls_false(client: openai.AsyncOpenAI):
+    """
+    Ensure only one tool call is returned when parallel_tool_calls is False.
+    """
+
+    models = await client.models.list()
+    model_name: str = models.data[0].id
+    chat_completion = await client.chat.completions.create(
+        messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
+        temperature=0,
+        max_completion_tokens=200,
+        model=model_name,
+        tools=[WEATHER_TOOL, SEARCH_TOOL],
+        logprobs=False,
+        parallel_tool_calls=False,
+    )
+
+    stop_reason = chat_completion.choices[0].finish_reason
+    non_streamed_tool_calls = chat_completion.choices[0].message.tool_calls
+
+    # make sure only 1 tool call is present
+    assert len(non_streamed_tool_calls) == 1
+    assert stop_reason == "tool_calls"
+
+    # make the same request, streaming
+    stream = await client.chat.completions.create(
+        model=model_name,
+        messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
+        temperature=0,
+        max_completion_tokens=200,
+        tools=[WEATHER_TOOL, SEARCH_TOOL],
+        logprobs=False,
+        parallel_tool_calls=False,
+        stream=True,
+    )
+
+    finish_reason_count: int = 0
+    tool_call_id_count: int = 0
+
+    async for chunk in stream:
+        # if there's a finish reason, make sure it's tool_calls
+        if chunk.choices[0].finish_reason:
+            finish_reason_count += 1
+            assert chunk.choices[0].finish_reason == "tool_calls"
+
+        streamed_tool_calls = chunk.choices[0].delta.tool_calls
+        if streamed_tool_calls and len(streamed_tool_calls) > 0:
+            tool_call = streamed_tool_calls[0]
+            if tool_call.id:
+                tool_call_id_count += 1
+
+    # make sure only 1 streaming tool call is present
+    assert tool_call_id_count == 1
+    assert finish_reason_count == 1
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index c4023a618528..98a385a1dcd5 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -559,9 +559,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
     ) = "none"
     reasoning_effort: Literal["low", "medium", "high"] | None = None
     include_reasoning: bool = True
+    parallel_tool_calls: bool | None = True
 
-    # NOTE this will be ignored by vLLM -- the model determines the behavior
-    parallel_tool_calls: bool | None = False
+    # NOTE `user` is ignored by vLLM
     user: str | None = None
 
     # --8<-- [start:chat-completion-sampling-params]
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 2a870dbc3afa..9a7051e0920a 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -55,6 +55,7 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_l
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.openai.tool_parsers import ToolParser
 from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolCall
+from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
 from vllm.entrypoints.utils import get_max_tokens, should_include_usage
 from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
 from vllm.logger import init_logger
@@ -1206,6 +1207,7 @@ class OpenAIServingChat(OpenAIServing):
 
                         finish_reason_sent[i] = True
 
+                choice_data = maybe_filter_parallel_tool_calls(choice_data, request)
                 chunk = ChatCompletionStreamResponse(
                     id=request_id,
                     object=chunk_object_type,
@@ -1531,6 +1533,7 @@ class OpenAIServingChat(OpenAIServing):
                     as_list(output.token_ids) if request.return_token_ids else None
                 ),
             )
 
+            choice_data = maybe_filter_parallel_tool_calls(choice_data, request)
             choices.append(choice_data)
 
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 09a135b701d0..d9feee917ff4 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -296,11 +296,7 @@ class OpenAIServing:
         parser = None
         if not enable_auto_tools or tool_parser_name is None:
             return parser
-        logger.info(
-            '"auto" tool choice has been enabled please note that while'
-            " the parallel_tool_calls client option is preset for "
-            "compatibility reasons, it will be ignored."
-        )
+        logger.info('"auto" tool choice has been enabled.')
 
         try:
             if tool_parser_name == "pythonic" and self.model_config.model.startswith(
diff --git a/vllm/entrypoints/openai/utils.py b/vllm/entrypoints/openai/utils.py
new file mode 100644
index 000000000000..6f37f6adff4c
--- /dev/null
+++ b/vllm/entrypoints/openai/utils.py
@@ -0,0 +1,37 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import TypeVar
+
+from vllm.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    ChatCompletionResponseChoice,
+    ChatCompletionResponseStreamChoice,
+)
+
+# TypeVar constrained to the two choice types handled by the filter below
+_ChatCompletionResponseChoiceT = TypeVar(
+    "_ChatCompletionResponseChoiceT",
+    ChatCompletionResponseChoice,
+    ChatCompletionResponseStreamChoice,
+)
+
+
+def maybe_filter_parallel_tool_calls(
+    choice: _ChatCompletionResponseChoiceT, request: ChatCompletionRequest
+) -> _ChatCompletionResponseChoiceT:
+    """Filter to the first tool call only when parallel_tool_calls is False."""
+
+    if request.parallel_tool_calls:
+        return choice
+
+    if isinstance(choice, ChatCompletionResponseChoice) and choice.message.tool_calls:
+        choice.message.tool_calls = choice.message.tool_calls[:1]
+    elif (
+        isinstance(choice, ChatCompletionResponseStreamChoice)
+        and choice.delta.tool_calls
+    ):
+        choice.delta.tool_calls = [
+            tool_call for tool_call in choice.delta.tool_calls if tool_call.index == 0
+        ]
+
+    return choice
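For context, here is a minimal client-side sketch of the behavior this change documents, using the OpenAI Python SDK against a running vLLM OpenAI-compatible server. The base URL, model name, and `get_weather` tool schema are illustrative assumptions, not part of this diff; the server is assumed to have been started with `--enable-auto-tool-choice` and a matching `--tool-call-parser`.

```python
# Hypothetical usage example; server URL, model name, and tool schema are
# assumptions for illustration and are not part of this change.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

weather_tool = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}

response = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",  # any tool-calling model served by vLLM
    messages=[
        {"role": "user", "content": "What is the weather in Berlin and in Paris?"}
    ],
    tools=[weather_tool],
    tool_choice="auto",
    # With parallel_tool_calls=False, vLLM now returns at most one tool call,
    # even if the model would otherwise emit several.
    parallel_tool_calls=False,
)

tool_calls = response.choices[0].message.tool_calls or []
assert len(tool_calls) <= 1
print([tc.function.name for tc in tool_calls])
```

Streaming requests are filtered the same way: only deltas for the first tool call (`index == 0`) are forwarded, as implemented in `maybe_filter_parallel_tool_calls` above.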