[Fix] correct tool_id for kimi-k2 when using tool_choice=required (#21259)

Co-authored-by: wangzhengtao <wangzhengtao@msh.team>
bigmoyan 2025-08-21 03:59:54 +08:00 committed by GitHub
parent 0cdbf5e61c
commit 582bbe6bd7
15 changed files with 283 additions and 166 deletions
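
In short: when the served model's hf_config.model_type is "kimi_k2", tool calls (notably with tool_choice="required") now get Kimi-K2's deterministic "functions.<function_name>:<index>" ID instead of the default random ID. A minimal sketch of the two formats this change distinguishes; sketch_tool_call_id is an illustrative stand-in, not vLLM code (the real helper added below is make_tool_call_id in vllm/entrypoints/chat_utils.py, which uses vLLM's random_uuid()):

import uuid

def sketch_tool_call_id(model_type: str, func_name: str, idx: int) -> str:
    # kimi_k2 models get a deterministic id: function name plus a running index
    if model_type == "kimi_k2":
        return f"functions.{func_name}:{idx}"
    # every other model keeps the random OpenAI-style id (uuid4 used here as a stand-in)
    return f"chatcmpl-tool-{uuid.uuid4().hex}"

print(sketch_tool_call_id("kimi_k2", "get_current_weather", 0))  # functions.get_current_weather:0
print(sketch_tool_call_id("qwen3", "get_current_weather", 0))    # chatcmpl-tool-<random hex>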

View File

@ -13,48 +13,7 @@ from ...utils import RemoteOpenAIServer
# any model with a chat template should work here
MODEL_NAME = "Qwen/Qwen3-0.6B"
tools = [
@pytest.fixture(scope="module")
def server(): # noqa: F811
args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
"half",
"--enable-auto-tool-choice",
"--guided-decoding-backend",
"xgrammar",
"--tool-call-parser",
"hermes",
"--reasoning-parser",
"qwen3",
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def client(server):
async with server.get_async_client() as async_client:
yield async_client
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("stream", [True, False])
@pytest.mark.parametrize("tool_choice", [
"auto", "required", {
"type": "function",
"function": {
"name": "get_current_weather"
}
}
])
@pytest.mark.parametrize("enable_thinking", [True, False])
async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str,
stream: bool, tool_choice: Union[str, dict],
enable_thinking: bool):
tools = [
{
"type": "function",
"function": {
@ -77,14 +36,12 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str,
},
"unit": {
"type": "string",
"description": "The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"],
},
"options": {
"$ref": "#/$defs/WeatherOptions",
"description": "Optional parameters for weather query",
},
},
"required": ["country", "unit"],
@ -149,8 +106,7 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str,
},
"unit": {
"type": "string",
"description": "The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"],
},
},
@ -158,9 +114,9 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str,
},
},
},
]
messages = [
{
"role": "user",
"content": "Hi! How are you doing today?"
@ -176,7 +132,51 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str,
"Can you tell me what the current weather is in Berlin and the "\ "Can you tell me what the current weather is in Berlin and the "\
"forecast for the next 5 days, in fahrenheit?", "forecast for the next 5 days, in fahrenheit?",
}, },
]
@pytest.fixture(scope="module")
def server(): # noqa: F811
args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
"half",
"--enable-auto-tool-choice",
"--guided-decoding-backend",
"xgrammar",
"--tool-call-parser",
"hermes",
"--reasoning-parser",
"qwen3",
"--gpu-memory-utilization",
"0.4"
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def client(server):
async with server.get_async_client() as async_client:
yield async_client
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("stream", [True, False])
@pytest.mark.parametrize("tool_choice", [
"auto", "required", {
"type": "function",
"function": {
"name": "get_current_weather"
}
}
])
@pytest.mark.parametrize("enable_thinking", [True, False])
async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str,
stream: bool, tool_choice: Union[str, dict],
enable_thinking: bool):
if not stream:
# Non-streaming test
chat_completion = await client.chat.completions.create(
@ -216,3 +216,71 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str,
output.extend(chunk.choices[0].delta.tool_calls)
assert len(output) > 0
@pytest.fixture(scope="module")
def k2_server(): # noqa: F811
args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
"half",
"--enable-auto-tool-choice",
"--guided-decoding-backend",
"xgrammar",
"--tool-call-parser",
"hermes",
"--reasoning-parser",
"qwen3",
"--gpu-memory-utilization",
"0.4",
]
# Hack to test the kimi_k2 tool-call ID format;
# setting kv_lora_rank=None avoids an error in the is_deepseek_mla check.
with RemoteOpenAIServer(MODEL_NAME,
args,
override_hf_configs={
"model_type": 'kimi_k2',
'kv_lora_rank': None
}) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def k2_client(k2_server):
async with k2_server.get_async_client() as async_client:
yield async_client
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("stream", [True, False])
@pytest.mark.parametrize("tool_choice", ["required"])
async def test_tool_id_kimi_k2(k2_client: openai.AsyncOpenAI, model_name: str,
stream: bool, tool_choice: str):
if not stream:
# Non-streaming test
chat_completion = await k2_client.chat.completions.create(
messages=messages,
model=model_name,
tools=tools,
tool_choice=tool_choice)
assert chat_completion.choices[0].message.tool_calls is not None
assert len(chat_completion.choices[0].message.tool_calls) > 0
assert chat_completion.choices[0].message.tool_calls[
0].id == 'functions.get_current_weather:0'
else:
# Streaming test
output_stream = await k2_client.chat.completions.create(
messages=messages,
model=model_name,
tools=tools,
tool_choice=tool_choice,
stream=True)
output = []
async for chunk in output_stream:
if chunk.choices and chunk.choices[0].delta.tool_calls:
output.extend(chunk.choices[0].delta.tool_calls)
for o in output:
assert o.id is None or o.id == 'functions.get_current_weather:0'

View File

@ -5,6 +5,7 @@ import asyncio
import copy
import functools
import importlib
import json
import os
import signal
import subprocess
@ -101,7 +102,8 @@ class RemoteOpenAIServer:
env_dict: Optional[dict[str, str]] = None,
seed: Optional[int] = 0,
auto_port: bool = True,
max_wait_seconds: Optional[float] = None,
override_hf_configs: Optional[dict[str, Any]] = None) -> None:
if auto_port:
if "-p" in vllm_serve_args or "--port" in vllm_serve_args:
raise ValueError("You have manually specified the port "
@ -120,6 +122,12 @@ class RemoteOpenAIServer:
vllm_serve_args = vllm_serve_args + ["--seed", str(seed)]
if override_hf_configs is not None:
vllm_serve_args = vllm_serve_args + [
"--hf-overrides",
json.dumps(override_hf_configs)
]
parser = FlexibleArgumentParser(
description="vLLM's remote OpenAI server.")
subparsers = parser.add_subparsers(required=False, dest="subparser")
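
For reference, a usage sketch of the new override_hf_configs hook in RemoteOpenAIServer (values taken from the k2_server fixture above): the dict is serialized with json.dumps and appended as a single --hf-overrides argument.

import json

override_hf_configs = {"model_type": "kimi_k2", "kv_lora_rank": None}
# RemoteOpenAIServer effectively appends the following to the serve args:
extra_args = ["--hf-overrides", json.dumps(override_hf_configs)]
print(extra_args)
# ['--hf-overrides', '{"model_type": "kimi_k2", "kv_lora_rank": null}']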

View File

@ -1345,5 +1345,18 @@ def apply_mistral_chat_template(
"template") "template")
raise ValueError(str(e)) from e raise ValueError(str(e)) from e
def get_history_tool_calls_cnt(conversation: list[ConversationMessage]):
idx = 0
for msg in conversation:
if msg['role'] == 'assistant':
tool_calls = msg.get('tool_calls')
idx += len(list(tool_calls)) if tool_calls is not None else 0  # noqa
return idx
def make_tool_call_id(id_type: str = 'random', func_name=None, idx=None):
if id_type == 'kimi_k2':
return f'functions.{func_name}:{idx}'
else:
# by default return random
return f"chatcmpl-tool-{random_uuid()}"

View File

@ -38,7 +38,7 @@ from typing_extensions import TypeAlias
from vllm import envs
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
make_tool_call_id)
from vllm.entrypoints.score_utils import (ScoreContentPartParam,
ScoreMultiModalParam)
from vllm.logger import init_logger
@ -1634,7 +1634,7 @@ class FunctionCall(OpenAIBaseModel):
class ToolCall(OpenAIBaseModel):
id: str = Field(default_factory=make_tool_call_id)
type: Literal["function"] = "function"
function: FunctionCall
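
Note: because make_tool_call_id defaults to id_type='random', using it as a default_factory keeps the previous behaviour for ToolCall. A quick check (assuming the protocol classes are importable as shown in this file):

from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall

tc = ToolCall(function=FunctionCall(name="get_current_weather", arguments="{}"))
print(tc.id)  # chatcmpl-tool-<random uuid>, same shape as before this change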

View File

@ -19,7 +19,8 @@ from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
ConversationMessage,
get_history_tool_calls_cnt,
make_tool_call_id)
from vllm.entrypoints.harmony_utils import (
get_developer_message, get_stop_tokens_for_assistant_actions,
get_streamable_parser_for_assistant, get_system_message, parse_chat_input,
@ -133,6 +134,10 @@ class OpenAIServingChat(OpenAIServing):
source = "model" if source == "auto" else source source = "model" if source == "auto" else source
logger.info("Using default chat sampling params from %s: %s", logger.info("Using default chat sampling params from %s: %s",
source, self.default_sampling_params) source, self.default_sampling_params)
if self.model_config.hf_config.model_type == 'kimi_k2':
self.tool_call_id_type = 'kimi_k2'
else:
self.tool_call_id_type = 'random'
self.use_harmony = model_config.hf_config.model_type == "gpt_oss"
if self.use_harmony:
@ -379,6 +384,7 @@ class OpenAIServingChat(OpenAIServing):
current_text: Optional[str],
delta_text: str,
function_name_returned: bool,
tool_call_idx: Optional[int] = None
) -> tuple[Optional[DeltaMessage], bool]:
if current_text is None or current_text == "":
# if the current text is empty, we cannot parse it
@ -424,8 +430,12 @@ class OpenAIServingChat(OpenAIServing):
current_tool_call = obj[-2]
function_name_returned = True
tool_call_id = make_tool_call_id(
id_type=self.tool_call_id_type,
func_name=current_tool_call["name"],
idx=tool_call_idx)
delta_message = DeltaMessage(tool_calls=[
DeltaToolCall(id=tool_call_id,
function=DeltaFunctionCall(
name=current_tool_call["name"],
arguments=arguments),
@ -491,6 +501,10 @@ class OpenAIServingChat(OpenAIServing):
all_previous_token_ids: Optional[list[list[int]]]
function_name_returned = [False] * num_choices
if self.tool_call_id_type == 'kimi_k2':
history_tool_call_cnt = get_history_tool_calls_cnt(conversation)
else:
history_tool_call_cnt = 0
# Always track previous_texts for comprehensive output logging
previous_texts = [""] * num_choices
@ -673,7 +687,6 @@ class OpenAIServingChat(OpenAIServing):
previous_text = previous_texts[i]
previous_token_ids = all_previous_token_ids[i]
current_text = previous_text + delta_text
# avoid the None + list error. # avoid the None + list error.
if previous_token_ids:
current_token_ids = previous_token_ids + as_list(
@ -733,7 +746,7 @@ class OpenAIServingChat(OpenAIServing):
index=i)
else:
delta_tool_call = DeltaToolCall(
id=make_tool_call_id(),
type="function",
function=DeltaFunctionCall(
name=tool_choice_function_name,
@ -764,7 +777,11 @@ class OpenAIServingChat(OpenAIServing):
previous_text=previous_text,
current_text=content,
delta_text=delta_text,
function_name_returned=fn_name_returned,
tool_call_idx=history_tool_call_cnt))
if (delta_message and delta_message.tool_calls and
delta_message.tool_calls[0].id is not None):
history_tool_call_cnt += 1
# update the previous values for the next iteration
previous_texts[i] = current_text
@ -1089,6 +1106,10 @@ class OpenAIServingChat(OpenAIServing):
assert final_res is not None
choices: list[ChatCompletionResponseChoice] = []
if self.tool_call_id_type == 'kimi_k2':
history_tool_call_cnt = get_history_tool_calls_cnt(conversation)
else:
history_tool_call_cnt = 0
role = self.get_chat_request_role(request)
for output in final_res.outputs:
@ -1194,17 +1215,26 @@ class OpenAIServingChat(OpenAIServing):
assert content is not None
tool_calls = TypeAdapter(
list[FunctionDefinition]).validate_json(content)
tool_call_ids = []
for tool_call in tool_calls:
tool_call_ids.append(
make_tool_call_id(id_type=self.tool_call_id_type,
func_name=tool_call.name,
idx=history_tool_call_cnt))
history_tool_call_cnt += 1
message = ChatMessage(
role=role,
content="",
reasoning_content=reasoning_content,
tool_calls=[
tool_call_class(id=tool_call_ids[i],
function=FunctionCall(
name=tool_call.name,
arguments=json.dumps(
tool_call.parameters,
ensure_ascii=False)))
for i, tool_call in enumerate(tool_calls)
],
reasoning_content=reasoning_content)
# if the request doesn't use tool choice
# OR specifies to not use a tool
@ -1248,7 +1278,6 @@ class OpenAIServingChat(OpenAIServing):
if (tool_call_info.content
and len(tool_call_info.content) > 0):
ret_content = tool_call_info.content
message = ChatMessage(role=role,
reasoning_content=reasoning_content,
content=ret_content)
@ -1327,12 +1356,11 @@ class OpenAIServingChat(OpenAIServing):
elif choice.message.tool_calls:
# For tool calls, log the function name and arguments
tool_call_descriptions = []
for tc in choice.message.tool_calls:
if hasattr(tc.function, "name") and hasattr(
tc.function, "arguments"):
tool_call_descriptions.append(
f"{tc.function.name}({tc.function.arguments})")
)
tool_calls_str = ", ".join(tool_call_descriptions) tool_calls_str = ", ".join(tool_call_descriptions)
output_text = f"[tool_calls: {tool_calls_str}]" output_text = f"[tool_calls: {tool_calls_str}]"

View File

@ -6,7 +6,7 @@ from typing import Union
import regex as re
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
@ -267,7 +267,7 @@ class DeepSeekV3ToolParser(ToolParser):
DeltaToolCall(
index=self.current_tool_id,
type="function",
id=make_tool_call_id(),
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True),

View File

@ -10,7 +10,7 @@ import partial_json_parser
import regex as re
from partial_json_parser.core.options import Allow
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
@ -203,7 +203,7 @@ class Granite20bFCToolParser(ToolParser):
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
type="function",
id=make_tool_call_id(),
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))

View File

@ -8,7 +8,7 @@ from typing import Union
import partial_json_parser
from partial_json_parser.core.options import Allow
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
@ -185,7 +185,7 @@ class GraniteToolParser(ToolParser):
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
type="function",
id=make_tool_call_id(),
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))

View File

@ -9,7 +9,7 @@ import partial_json_parser
import regex as re
from partial_json_parser.core.options import Allow
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
@ -307,7 +307,7 @@ class Hermes2ProToolParser(ToolParser):
return DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
type="function",
id=make_tool_call_id(),
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))

View File

@ -8,7 +8,7 @@ from typing import Union
import partial_json_parser
from partial_json_parser.core.options import Allow
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
@ -107,7 +107,7 @@ class Internlm2ToolParser(ToolParser):
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
type="function",
id=make_tool_call_id(),
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))

View File

@ -9,7 +9,7 @@ import partial_json_parser
import regex as re
from partial_json_parser.core.options import Allow
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
@ -222,7 +222,7 @@ class JambaToolParser(ToolParser):
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
type="function",
id=make_tool_call_id(),
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))

View File

@ -10,7 +10,7 @@ import regex as re
from partial_json_parser.core.options import Allow
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
@ -213,7 +213,7 @@ class Llama3JsonToolParser(ToolParser):
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
type="function",
id=make_tool_call_id(),
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))

View File

@ -7,7 +7,7 @@ from typing import Any, Optional, Union
import regex as re
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
@ -394,7 +394,7 @@ class MinimaxToolParser(ToolParser):
sent_tools.append({
"sent_name": False,
"sent_arguments": "",
"id": make_tool_call_id(),
})
while len(tool_ids) < tool_count:

View File

@ -8,7 +8,7 @@ from typing import Any, Optional
import regex as re
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage,
ExtractedToolCallInformation,
@ -74,7 +74,7 @@ class Phi4MiniJsonToolParser(ToolParser):
tool_calls: list[ToolCall] = [
ToolCall(
id=make_tool_call_id(),
type="function",
function=FunctionCall(
name=raw_function_call["name"],

View File

@ -7,7 +7,7 @@ from typing import Any, Optional, Union
import regex as re
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
@ -226,7 +226,7 @@ class xLAMToolParser(ToolParser):
function_name = name_match.group(1)
# The test expects us to send just the name first
tool_id = make_tool_call_id()
delta = DeltaMessage(tool_calls=[
DeltaToolCall(
index=0,