From 582bbe6bd708d01d74d6d02d6ef59b4c3c34a7b1 Mon Sep 17 00:00:00 2001
From: bigmoyan
Date: Thu, 21 Aug 2025 03:59:54 +0800
Subject: [PATCH] [Fix] correct tool_id for kimi-k2 when using tool_choice=required (#21259)

Co-authored-by: wangzhengtao
---
 .../test_completion_with_function_calling.py  | 314 +++++++++++-------
 tests/utils.py                                |  10 +-
 vllm/entrypoints/chat_utils.py                |  17 +-
 vllm/entrypoints/openai/protocol.py           |   4 +-
 vllm/entrypoints/openai/serving_chat.py       |  64 +++-
 .../tool_parsers/deepseekv3_tool_parser.py    |   4 +-
 .../granite_20b_fc_tool_parser.py             |   4 +-
 .../tool_parsers/granite_tool_parser.py       |   4 +-
 .../openai/tool_parsers/hermes_tool_parser.py |   4 +-
 .../tool_parsers/internlm2_tool_parser.py     |   4 +-
 .../openai/tool_parsers/jamba_tool_parser.py  |   4 +-
 .../openai/tool_parsers/llama_tool_parser.py  |   4 +-
 .../tool_parsers/minimax_tool_parser.py       |   4 +-
 .../tool_parsers/phi4mini_tool_parser.py      |   4 +-
 .../openai/tool_parsers/xlam_tool_parser.py   |   4 +-
 15 files changed, 283 insertions(+), 166 deletions(-)

diff --git a/tests/entrypoints/openai/test_completion_with_function_calling.py b/tests/entrypoints/openai/test_completion_with_function_calling.py
index a5b081f86107..4ef5d4e8a699 100644
--- a/tests/entrypoints/openai/test_completion_with_function_calling.py
+++ b/tests/entrypoints/openai/test_completion_with_function_calling.py
@@ -13,6 +13,127 @@ from ...utils import RemoteOpenAIServer
 
 # any model with a chat template should work here
 MODEL_NAME = "Qwen/Qwen3-0.6B"
+tools = [
+    {
+        "type": "function",
+        "function": {
+            "name": "get_current_weather",
+            "description": "Get the current weather in a given location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "city": {
+                        "type": "string",
+                        "description":
+                        "The city to find the weather for, e.g. 'Vienna'",
+                        "default": "Vienna",
+                    },
+                    "country": {
+                        "type":
+                        "string",
+                        "description":
+                        "The country that the city is in, e.g. 'Austria'",
+                    },
+                    "unit": {
+                        "type": "string",
+                        "description": "The unit to fetch the temperature in",
+                        "enum": ["celsius", "fahrenheit"],
+                    },
+                    "options": {
+                        "$ref": "#/$defs/WeatherOptions",
+                        "description": "Optional parameters for weather query",
+                    },
+                },
+                "required": ["country", "unit"],
+                "$defs": {
+                    "WeatherOptions": {
+                        "title": "WeatherOptions",
+                        "type": "object",
+                        "additionalProperties": False,
+                        "properties": {
+                            "unit": {
+                                "type": "string",
+                                "enum": ["celsius", "fahrenheit"],
+                                "default": "celsius",
+                                "description": "Temperature unit",
+                                "title": "Temperature Unit",
+                            },
+                            "include_forecast": {
+                                "type": "boolean",
+                                "default": False,
+                                "description":
+                                "Whether to include a 24-hour forecast",
+                                "title": "Include Forecast",
+                            },
+                            "language": {
+                                "type": "string",
+                                "default": "zh-CN",
+                                "description": "Language of the response",
+                                "title": "Language",
+                                "enum": ["zh-CN", "en-US", "ja-JP"],
+                            },
+                        },
+                    },
+                },
+            },
+        },
+    },
+    {
+        "type": "function",
+        "function": {
+            "name": "get_forecast",
+            "description": "Get the weather forecast for a given location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "city": {
+                        "type": "string",
+                        "description":
+                        "The city to get the forecast for, e.g. 'Vienna'",
+                        "default": "Vienna",
+                    },
+                    "country": {
+                        "type":
+                        "string",
+                        "description":
+                        "The country that the city is in, e.g. 'Austria'",
+                    },
+                    "days": {
+                        "type":
+                        "integer",
+                        "description":
+                        "Number of days to get the forecast for (1-7)",
+                    },
+                    "unit": {
+                        "type": "string",
+                        "description": "The unit to fetch the temperature in",
+                        "enum": ["celsius", "fahrenheit"],
+                    },
+                },
+                "required": ["country", "days", "unit"],
+            },
+        },
+    },
+]
+
+messages = [
+    {
+        "role": "user",
+        "content": "Hi! How are you doing today?"
+    },
+    {
+        "role": "assistant",
+        "content": "I'm doing well! How can I help you?"
+    },
+    {
+        "role":
+        "user",
+        "content":
+        "Can you tell me what the current weather is in Berlin and the "\
+        "forecast for the next 5 days, in fahrenheit?",
+    },
+]
+
 
 @pytest.fixture(scope="module")
 def server():  # noqa: F811
@@ -27,6 +148,8 @@ def server():  # noqa: F811
         "hermes",
         "--reasoning-parser",
         "qwen3",
+        "--gpu-memory-utilization",
+        "0.4"
     ]
 
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@@ -54,129 +177,6 @@ async def client(server):
 async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str,
                                  stream: bool, tool_choice: Union[str, dict],
                                  enable_thinking: bool):
-    tools = [
-        {
-            "type": "function",
-            "function": {
-                "name": "get_current_weather",
-                "description": "Get the current weather in a given location",
-                "parameters": {
-                    "type": "object",
-                    "properties": {
-                        "city": {
-                            "type": "string",
-                            "description":
-                            "The city to find the weather for, e.g. 'Vienna'",
-                            "default": "Vienna",
-                        },
-                        "country": {
-                            "type":
-                            "string",
-                            "description":
-                            "The country that the city is in, e.g. 'Austria'",
-                        },
-                        "unit": {
-                            "type": "string",
-                            "description":
-                            "The unit to fetch the temperature in",
-                            "enum": ["celsius", "fahrenheit"],
-                        },
-                        "options": {
-                            "$ref": "#/$defs/WeatherOptions",
-                            "description":
-                            "Optional parameters for weather query",
-                        },
-                    },
-                    "required": ["country", "unit"],
-                    "$defs": {
-                        "WeatherOptions": {
-                            "title": "WeatherOptions",
-                            "type": "object",
-                            "additionalProperties": False,
-                            "properties": {
-                                "unit": {
-                                    "type": "string",
-                                    "enum": ["celsius", "fahrenheit"],
-                                    "default": "celsius",
-                                    "description": "Temperature unit",
-                                    "title": "Temperature Unit",
-                                },
-                                "include_forecast": {
-                                    "type": "boolean",
-                                    "default": False,
-                                    "description":
-                                    "Whether to include a 24-hour forecast",
-                                    "title": "Include Forecast",
-                                },
-                                "language": {
-                                    "type": "string",
-                                    "default": "zh-CN",
-                                    "description": "Language of the response",
-                                    "title": "Language",
-                                    "enum": ["zh-CN", "en-US", "ja-JP"],
-                                },
-                            },
-                        },
-                    },
-                },
-            },
-        },
-        {
-            "type": "function",
-            "function": {
-                "name": "get_forecast",
-                "description": "Get the weather forecast for a given location",
-                "parameters": {
-                    "type": "object",
-                    "properties": {
-                        "city": {
-                            "type": "string",
-                            "description":
-                            "The city to get the forecast for, e.g. 'Vienna'",
-                            "default": "Vienna",
-                        },
-                        "country": {
-                            "type":
-                            "string",
-                            "description":
-                            "The country that the city is in, e.g. 'Austria'",
-                        },
-                        "days": {
-                            "type":
-                            "integer",
-                            "description":
-                            "Number of days to get the forecast for (1-7)",
-                        },
-                        "unit": {
-                            "type": "string",
-                            "description":
-                            "The unit to fetch the temperature in",
-                            "enum": ["celsius", "fahrenheit"],
-                        },
-                    },
-                    "required": ["country", "days", "unit"],
-                },
-            },
-        },
-    ]
-
-    messages = [
-        {
-            "role": "user",
-            "content": "Hi! How are you doing today?"
-        },
-        {
-            "role": "assistant",
-            "content": "I'm doing well! How can I help you?"
-        },
-        {
-            "role":
-            "user",
-            "content":
-            "Can you tell me what the current weather is in Berlin and the "\
-            "forecast for the next 5 days, in fahrenheit?",
-        },
-    ]
 
     if not stream:
         # Non-streaming test
         chat_completion = await client.chat.completions.create(
@@ -216,3 +216,71 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str,
                 output.extend(chunk.choices[0].delta.tool_calls)
 
         assert len(output) > 0
+
+
+@pytest.fixture(scope="module")
+def k2_server():  # noqa: F811
+    args = [
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "half",
+        "--enable-auto-tool-choice",
+        "--guided-decoding-backend",
+        "xgrammar",
+        "--tool-call-parser",
+        "hermes",
+        "--reasoning-parser",
+        "qwen3",
+        "--gpu-memory-utilization",
+        "0.4",
+    ]
+    # hack to test the kimi_k2 tool_id format;
+    # avoid an error in the is_deepseek_mla check by setting kv_lora_rank=null
+    with RemoteOpenAIServer(MODEL_NAME,
+                            args,
+                            override_hf_configs={
+                                "model_type": "kimi_k2",
+                                "kv_lora_rank": None
+                            }) as remote_server:
+        yield remote_server
+
+
+@pytest_asyncio.fixture
+async def k2_client(k2_server):
+    async with k2_server.get_async_client() as async_client:
+        yield async_client
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.parametrize("stream", [True, False])
+@pytest.mark.parametrize("tool_choice", ["required"])
+async def test_tool_id_kimi_k2(k2_client: openai.AsyncOpenAI, model_name: str,
+                               stream: bool, tool_choice: str):
+    if not stream:
+        # Non-streaming test
+        chat_completion = await k2_client.chat.completions.create(
+            messages=messages,
+            model=model_name,
+            tools=tools,
+            tool_choice=tool_choice)
+        assert chat_completion.choices[0].message.tool_calls is not None
+        assert len(chat_completion.choices[0].message.tool_calls) > 0
+        assert chat_completion.choices[0].message.tool_calls[
+            0].id == 'functions.get_current_weather:0'
+    else:
+        # Streaming test
+        output_stream = await k2_client.chat.completions.create(
+            messages=messages,
+            model=model_name,
+            tools=tools,
+            tool_choice=tool_choice,
+            stream=True)
+
+        output = []
+        async for chunk in output_stream:
+            if chunk.choices and chunk.choices[0].delta.tool_calls:
+                output.extend(chunk.choices[0].delta.tool_calls)
+        for o in output:
+            assert o.id is None or o.id == 'functions.get_current_weather:0'
diff --git a/tests/utils.py b/tests/utils.py
index e98707fb4447..4dba5494665a 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -5,6 +5,7 @@ import asyncio
 import copy
 import functools
 import importlib
+import json
 import os
 import signal
 import subprocess
@@ -101,7 +102,8 @@ class RemoteOpenAIServer:
                  env_dict: Optional[dict[str, str]] = None,
                  seed: Optional[int] = 0,
                  auto_port: bool = True,
-                 max_wait_seconds: Optional[float] = None) -> None:
+                 max_wait_seconds: Optional[float] = None,
+                 override_hf_configs: Optional[dict[str, Any]] = None) -> None:
         if auto_port:
             if "-p" in vllm_serve_args or "--port" in vllm_serve_args:
                 raise ValueError("You have manually specified the port "
@@ -120,6 +122,12 @@ class RemoteOpenAIServer:
 
             vllm_serve_args = vllm_serve_args + ["--seed", str(seed)]
 
+        if override_hf_configs is not None:
+            vllm_serve_args = vllm_serve_args + [
+                "--hf-overrides",
+                json.dumps(override_hf_configs)
+            ]
+
         parser = FlexibleArgumentParser(
             description="vLLM's remote OpenAI server.")
         subparsers = parser.add_subparsers(required=False, dest="subparser")
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index 74c8093f4967..87772a499f42 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -1345,5 +1345,18 @@ def apply_mistral_chat_template(
                 "template")
         raise ValueError(str(e)) from e
 
 
-def random_tool_call_id() -> str:
-    return f"chatcmpl-tool-{random_uuid()}"
+def get_history_tool_calls_cnt(conversation: list[ConversationMessage]):
+    idx = 0
+    for msg in conversation:
+        if msg['role'] == 'assistant':
+            tool_calls = msg.get('tool_calls')
+            idx += len(list(tool_calls)) if tool_calls is not None else 0
+    return idx
+
+
+def make_tool_call_id(id_type: str = 'random', func_name=None, idx=None):
+    if id_type == 'kimi_k2':
+        return f'functions.{func_name}:{idx}'
+    else:
+        # by default, return a random ID
+        return f"chatcmpl-tool-{random_uuid()}"
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 39facd4d53d3..a44868973f5d 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -38,7 +38,7 @@ from typing_extensions import TypeAlias
 
 from vllm import envs
 from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
-                                         random_tool_call_id)
+                                         make_tool_call_id)
 from vllm.entrypoints.score_utils import (ScoreContentPartParam,
                                           ScoreMultiModalParam)
 from vllm.logger import init_logger
@@ -1634,7 +1634,7 @@ class FunctionCall(OpenAIBaseModel):
 
 
 class ToolCall(OpenAIBaseModel):
-    id: str = Field(default_factory=random_tool_call_id)
+    id: str = Field(default_factory=make_tool_call_id)
     type: Literal["function"] = "function"
     function: FunctionCall
 
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index d57868847eed..65aac23ee618 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -19,7 +19,8 @@ from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
                                          ConversationMessage,
-                                         random_tool_call_id)
+                                         get_history_tool_calls_cnt,
+                                         make_tool_call_id)
 from vllm.entrypoints.harmony_utils import (
     get_developer_message, get_stop_tokens_for_assistant_actions,
     get_streamable_parser_for_assistant, get_system_message, parse_chat_input,
@@ -133,6 +134,10 @@ class OpenAIServingChat(OpenAIServing):
                 source = "model" if source == "auto" else source
                 logger.info("Using default chat sampling params from %s: %s",
                             source, self.default_sampling_params)
+        if self.model_config.hf_config.model_type == 'kimi_k2':
+            self.tool_call_id_type = 'kimi_k2'
+        else:
+            self.tool_call_id_type = 'random'
 
         self.use_harmony = model_config.hf_config.model_type == "gpt_oss"
         if self.use_harmony:
@@ -379,6 +384,7 @@ class OpenAIServingChat(OpenAIServing):
         current_text: Optional[str],
         delta_text: str,
         function_name_returned: bool,
+        tool_call_idx: Optional[int] = None
     ) -> tuple[Optional[DeltaMessage], bool]:
         if current_text is None or current_text == "":
             # if the current text is empty, we cannot parse it
@@ -424,8 +430,12 @@ class OpenAIServingChat(OpenAIServing):
                 current_tool_call = obj[-2]
 
             function_name_returned = True
+            tool_call_id = make_tool_call_id(
+                id_type=self.tool_call_id_type,
+                func_name=current_tool_call["name"],
+                idx=tool_call_idx)
             delta_message = DeltaMessage(tool_calls=[
-                DeltaToolCall(id=random_tool_call_id(),
+                DeltaToolCall(id=tool_call_id,
                               function=DeltaFunctionCall(
                                   name=current_tool_call["name"],
                                   arguments=arguments),
@@ -491,6 +501,10 @@ class OpenAIServingChat(OpenAIServing):
         all_previous_token_ids: Optional[list[list[int]]]
         function_name_returned = [False] * num_choices
+        if self.tool_call_id_type == 'kimi_k2':
+            history_tool_call_cnt = get_history_tool_calls_cnt(conversation)
+        else:
+            history_tool_call_cnt = 0
 
         # Always track previous_texts for comprehensive output logging
         previous_texts = [""] * num_choices
@@ -673,7 +687,6 @@ class OpenAIServingChat(OpenAIServing):
                     previous_text = previous_texts[i]
                     previous_token_ids = all_previous_token_ids[i]
                     current_text = previous_text + delta_text
-
                     # avoid the None + list error.
                     if previous_token_ids:
                         current_token_ids = previous_token_ids + as_list(
@@ -733,7 +746,7 @@ class OpenAIServingChat(OpenAIServing):
                                 index=i)
                         else:
                             delta_tool_call = DeltaToolCall(
-                                id=random_tool_call_id(),
+                                id=make_tool_call_id(),
                                 type="function",
                                 function=DeltaFunctionCall(
                                     name=tool_choice_function_name,
@@ -764,7 +777,11 @@ class OpenAIServingChat(OpenAIServing):
                                 previous_text=previous_text,
                                 current_text=content,
                                 delta_text=delta_text,
-                                function_name_returned=fn_name_returned))
+                                function_name_returned=fn_name_returned,
+                                tool_call_idx=history_tool_call_cnt))
+                        if (delta_message and delta_message.tool_calls and
+                                delta_message.tool_calls[0].id is not None):
+                            history_tool_call_cnt += 1
 
                     # update the previous values for the next iteration
                     previous_texts[i] = current_text
@@ -1089,6 +1106,10 @@ class OpenAIServingChat(OpenAIServing):
         assert final_res is not None
 
         choices: list[ChatCompletionResponseChoice] = []
+        if self.tool_call_id_type == 'kimi_k2':
+            history_tool_call_cnt = get_history_tool_calls_cnt(conversation)
+        else:
+            history_tool_call_cnt = 0
 
         role = self.get_chat_request_role(request)
         for output in final_res.outputs:
@@ -1194,17 +1215,26 @@ class OpenAIServingChat(OpenAIServing):
                 assert content is not None
                 tool_calls = TypeAdapter(
                     list[FunctionDefinition]).validate_json(content)
+                tool_call_ids = []
+                for tool_call in tool_calls:
+                    tool_call_ids.append(
+                        make_tool_call_id(id_type=self.tool_call_id_type,
+                                          func_name=tool_call.name,
+                                          idx=history_tool_call_cnt))
+                    history_tool_call_cnt += 1
                 message = ChatMessage(
                     role=role,
                     content="",
-                    reasoning_content=reasoning_content,
                     tool_calls=[
-                        tool_call_class(function=FunctionCall(
-                            name=tool_call.name,
-                            arguments=json.dumps(tool_call.parameters,
-                                                 ensure_ascii=False)))
-                        for tool_call in tool_calls
-                    ])
+                        tool_call_class(id=tool_call_ids[i],
+                                        function=FunctionCall(
+                                            name=tool_call.name,
+                                            arguments=json.dumps(
+                                                tool_call.parameters,
+                                                ensure_ascii=False)))
+                        for i, tool_call in enumerate(tool_calls)
+                    ],
+                    reasoning_content=reasoning_content)
 
             # if the request doesn't use tool choice
             # OR specifies to not use a tool
@@ -1248,7 +1278,6 @@ class OpenAIServingChat(OpenAIServing):
                 if (tool_call_info.content
                         and len(tool_call_info.content) > 0):
                     ret_content = tool_call_info.content
-
                 message = ChatMessage(role=role,
                                       reasoning_content=reasoning_content,
                                       content=ret_content)
@@ -1327,12 +1356,11 @@ class OpenAIServingChat(OpenAIServing):
             elif choice.message.tool_calls:
                 # For tool calls, log the function name and arguments
                 tool_call_descriptions = []
-                for tool_call in choice.message.tool_calls:
-                    if hasattr(tool_call.function, "name") and hasattr(
-                            tool_call.function, "arguments"):
+                for tc in choice.message.tool_calls:
+                    if hasattr(tc.function, "name") and hasattr(
+                            tc.function, "arguments"):
                         tool_call_descriptions.append(
-                            f"{tool_call.function.name}({tool_call.function.arguments})"
-                        )
+                            f"{tc.function.name}({tc.function.arguments})")
                 tool_calls_str = ", ".join(tool_call_descriptions)
                 output_text = f"[tool_calls: {tool_calls_str}]"
diff --git a/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py
index da4760ad1b64..ac272b0c3b20 100644
--- a/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py
@@ -6,7 +6,7 @@ from typing import Union
 
 import regex as re
 
-from vllm.entrypoints.chat_utils import random_tool_call_id
+from vllm.entrypoints.chat_utils import make_tool_call_id
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaFunctionCall, DeltaMessage,
                                               DeltaToolCall,
@@ -267,7 +267,7 @@ class DeepSeekV3ToolParser(ToolParser):
                             DeltaToolCall(
                                 index=self.current_tool_id,
                                 type="function",
-                                id=random_tool_call_id(),
+                                id=make_tool_call_id(),
                                 function=DeltaFunctionCall(
                                     name=function_name).model_dump(
                                         exclude_none=True),
diff --git a/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py
index 5508ba6a3940..824b100f357b 100644
--- a/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py
@@ -10,7 +10,7 @@ import partial_json_parser
 import regex as re
 from partial_json_parser.core.options import Allow
 
-from vllm.entrypoints.chat_utils import random_tool_call_id
+from vllm.entrypoints.chat_utils import make_tool_call_id
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaFunctionCall, DeltaMessage,
                                               DeltaToolCall,
@@ -203,7 +203,7 @@ class Granite20bFCToolParser(ToolParser):
                 delta = DeltaMessage(tool_calls=[
                     DeltaToolCall(index=self.current_tool_id,
                                   type="function",
-                                  id=random_tool_call_id(),
+                                  id=make_tool_call_id(),
                                   function=DeltaFunctionCall(
                                       name=function_name).model_dump(
                                           exclude_none=True))
diff --git a/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py
index fcc5b7edda83..ac517616a95b 100644
--- a/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py
@@ -8,7 +8,7 @@ from typing import Union
 import partial_json_parser
 from partial_json_parser.core.options import Allow
 
-from vllm.entrypoints.chat_utils import random_tool_call_id
+from vllm.entrypoints.chat_utils import make_tool_call_id
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaFunctionCall, DeltaMessage,
                                               DeltaToolCall,
@@ -185,7 +185,7 @@ class GraniteToolParser(ToolParser):
                 delta = DeltaMessage(tool_calls=[
                     DeltaToolCall(index=self.current_tool_id,
                                   type="function",
-                                  id=random_tool_call_id(),
+                                  id=make_tool_call_id(),
                                   function=DeltaFunctionCall(
                                       name=function_name).model_dump(
                                           exclude_none=True))
diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
index d126130ab9bc..a6ce33af6bd0 100644
--- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
@@ -9,7 +9,7 @@ import partial_json_parser
 import regex as re
 from partial_json_parser.core.options import Allow
 
-from vllm.entrypoints.chat_utils import random_tool_call_id
+from vllm.entrypoints.chat_utils import make_tool_call_id
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaFunctionCall, DeltaMessage,
                                               DeltaToolCall,
@@ -307,7 +307,7 @@ class Hermes2ProToolParser(ToolParser):
                     return DeltaMessage(tool_calls=[
                         DeltaToolCall(index=self.current_tool_id,
                                       type="function",
-                                      id=random_tool_call_id(),
+                                      id=make_tool_call_id(),
                                       function=DeltaFunctionCall(
                                           name=function_name).model_dump(
                                               exclude_none=True))
diff --git a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py
index 92004de030d1..6ef8fadf59ac 100644
--- a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py
@@ -8,7 +8,7 @@ from typing import Union
 import partial_json_parser
 from partial_json_parser.core.options import Allow
 
-from vllm.entrypoints.chat_utils import random_tool_call_id
+from vllm.entrypoints.chat_utils import make_tool_call_id
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaFunctionCall, DeltaMessage,
                                               DeltaToolCall,
@@ -107,7 +107,7 @@ class Internlm2ToolParser(ToolParser):
             delta = DeltaMessage(tool_calls=[
                 DeltaToolCall(index=self.current_tool_id,
                               type="function",
-                              id=random_tool_call_id(),
+                              id=make_tool_call_id(),
                               function=DeltaFunctionCall(
                                   name=function_name).model_dump(
                                       exclude_none=True))
diff --git a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py
index 66b483d8b0f6..3b41f6034704 100644
--- a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py
@@ -9,7 +9,7 @@ import partial_json_parser
 import regex as re
 from partial_json_parser.core.options import Allow
 
-from vllm.entrypoints.chat_utils import random_tool_call_id
+from vllm.entrypoints.chat_utils import make_tool_call_id
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaFunctionCall, DeltaMessage,
                                               DeltaToolCall,
@@ -222,7 +222,7 @@ class JambaToolParser(ToolParser):
                 delta = DeltaMessage(tool_calls=[
                     DeltaToolCall(index=self.current_tool_id,
                                   type="function",
-                                  id=random_tool_call_id(),
+                                  id=make_tool_call_id(),
                                   function=DeltaFunctionCall(
                                       name=function_name).model_dump(
                                           exclude_none=True))
diff --git a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
index 194a144ad576..31b19c8db416 100644
--- a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
@@ -10,7 +10,7 @@ import regex as re
 from partial_json_parser.core.options import Allow
 from transformers import PreTrainedTokenizerBase
 
-from vllm.entrypoints.chat_utils import random_tool_call_id
+from vllm.entrypoints.chat_utils import make_tool_call_id
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaFunctionCall, DeltaMessage,
                                               DeltaToolCall,
@@ -213,7 +213,7 @@ class Llama3JsonToolParser(ToolParser):
             delta = DeltaMessage(tool_calls=[
                 DeltaToolCall(index=self.current_tool_id,
                               type="function",
-                              id=random_tool_call_id(),
+                              id=make_tool_call_id(),
                               function=DeltaFunctionCall(
                                   name=function_name).model_dump(
                                       exclude_none=True))
diff --git a/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py
index 226309ef293a..283e6095013d 100644
--- a/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py
@@ -7,7 +7,7 @@ from typing import Any, Optional, Union
 
 import regex as re
 
-from vllm.entrypoints.chat_utils import random_tool_call_id
+from vllm.entrypoints.chat_utils import make_tool_call_id
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaFunctionCall, DeltaMessage,
                                               DeltaToolCall,
@@ -394,7 +394,7 @@ class MinimaxToolParser(ToolParser):
                 sent_tools.append({
                     "sent_name": False,
                     "sent_arguments": "",
-                    "id": random_tool_call_id(),
+                    "id": make_tool_call_id(),
                 })
 
             while len(tool_ids) < tool_count:
diff --git a/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py
index 5501028cf36b..85dd56213c6a 100644
--- a/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py
@@ -8,7 +8,7 @@ from typing import Any, Optional
 import regex as re
 from transformers import PreTrainedTokenizerBase
 
-from vllm.entrypoints.chat_utils import random_tool_call_id
+from vllm.entrypoints.chat_utils import make_tool_call_id
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaMessage,
                                               ExtractedToolCallInformation,
@@ -74,7 +74,7 @@ class Phi4MiniJsonToolParser(ToolParser):
 
         tool_calls: list[ToolCall] = [
             ToolCall(
-                id=random_tool_call_id(),
+                id=make_tool_call_id(),
                 type="function",
                 function=FunctionCall(
                     name=raw_function_call["name"],
diff --git a/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py
index 321718b1c950..87cd413b3720 100644
--- a/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py
@@ -7,7 +7,7 @@ from typing import Any, Optional, Union
 
 import regex as re
 
-from vllm.entrypoints.chat_utils import random_tool_call_id
+from vllm.entrypoints.chat_utils import make_tool_call_id
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaFunctionCall, DeltaMessage,
                                               DeltaToolCall,
@@ -226,7 +226,7 @@ class xLAMToolParser(ToolParser):
                 function_name = name_match.group(1)
 
                 # The test expects us to send just the name first
-                tool_id = make_tool_call_id()
+                tool_id = make_tool_call_id()
                 delta = DeltaMessage(tool_calls=[
                     DeltaToolCall(
                         index=0,
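
For reference, a minimal, self-contained sketch of the tool-call ID scheme this patch
introduces, mirroring make_tool_call_id and get_history_tool_calls_cnt from
vllm/entrypoints/chat_utils.py. Here uuid4 stands in for vLLM's internal random_uuid
helper, and plain dicts stand in for ConversationMessage; both are assumptions for
illustration only.

from typing import Optional
from uuid import uuid4


def make_tool_call_id(id_type: str = 'random',
                      func_name: Optional[str] = None,
                      idx: Optional[int] = None) -> str:
    # kimi-k2 clients expect deterministic IDs: functions.<name>:<index>
    if id_type == 'kimi_k2':
        return f'functions.{func_name}:{idx}'
    # every other model keeps the random chatcmpl-tool-<uuid> format
    return f"chatcmpl-tool-{uuid4().hex}"


def get_history_tool_calls_cnt(conversation: list[dict]) -> int:
    # Count tool calls in prior assistant turns so new kimi_k2 IDs continue
    # the numbering across turns instead of restarting at 0.
    return sum(
        len(msg.get('tool_calls') or [])
        for msg in conversation if msg.get('role') == 'assistant')


if __name__ == '__main__':
    history = [{
        'role': 'assistant',
        'tool_calls': [{'id': 'functions.get_forecast:0'}],
    }]
    start = get_history_tool_calls_cnt(history)  # one prior call -> 1
    assert make_tool_call_id('kimi_k2', 'get_current_weather', start) \
        == 'functions.get_current_weather:1'
    assert make_tool_call_id().startswith('chatcmpl-tool-')

This is why the streaming path in serving_chat.py only increments
history_tool_call_cnt when a delta actually carries a tool-call ID: the counter must
advance once per emitted call, and the matching test accepts either None or the
deterministic ID in each streamed chunk.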