[Fix] correct tool_id for kimi-k2 when use tool_choice=required (#21259)

Co-authored-by: wangzhengtao <wangzhengtao@msh.team>

parent 0cdbf5e61c
commit 582bbe6bd7
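The change in one line: with `tool_choice="required"`, vLLM used to mint a random `chatcmpl-tool-{uuid}` id for every tool call, but kimi-k2 models expect ids of the form `functions.{func_name}:{idx}`, where the index counts tool calls across the whole conversation. A minimal sketch of the two id schemes (the standalone function below is illustrative, not vLLM's API):

import uuid


def sketch_tool_call_id(model_type: str, func_name: str, idx: int) -> str:
    # kimi_k2 ids encode the function name plus a running per-conversation index
    if model_type == "kimi_k2":
        return f"functions.{func_name}:{idx}"
    # any other model keeps the previous random-id behavior
    return f"chatcmpl-tool-{uuid.uuid4()}"


print(sketch_tool_call_id("kimi_k2", "get_current_weather", 0))
# functions.get_current_weather:0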
@@ -13,6 +13,127 @@ from ...utils import RemoteOpenAIServer

 # any model with a chat template should work here
 MODEL_NAME = "Qwen/Qwen3-0.6B"
+
+tools = [
+    {
+        "type": "function",
+        "function": {
+            "name": "get_current_weather",
+            "description": "Get the current weather in a given location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "city": {
+                        "type": "string",
+                        "description":
+                        "The city to find the weather for, e.g. 'Vienna'",
+                        "default": "Vienna",
+                    },
+                    "country": {
+                        "type":
+                        "string",
+                        "description":
+                        "The country that the city is in, e.g. 'Austria'",
+                    },
+                    "unit": {
+                        "type": "string",
+                        "description": "The unit to fetch the temperature in",
+                        "enum": ["celsius", "fahrenheit"],
+                    },
+                    "options": {
+                        "$ref": "#/$defs/WeatherOptions",
+                        "description": "Optional parameters for weather query",
+                    },
+                },
+                "required": ["country", "unit"],
+                "$defs": {
+                    "WeatherOptions": {
+                        "title": "WeatherOptions",
+                        "type": "object",
+                        "additionalProperties": False,
+                        "properties": {
+                            "unit": {
+                                "type": "string",
+                                "enum": ["celsius", "fahrenheit"],
+                                "default": "celsius",
+                                "description": "Temperature unit",
+                                "title": "Temperature Unit",
+                            },
+                            "include_forecast": {
+                                "type": "boolean",
+                                "default": False,
+                                "description":
+                                "Whether to include a 24-hour forecast",
+                                "title": "Include Forecast",
+                            },
+                            "language": {
+                                "type": "string",
+                                "default": "zh-CN",
+                                "description": "Language of the response",
+                                "title": "Language",
+                                "enum": ["zh-CN", "en-US", "ja-JP"],
+                            },
+                        },
+                    },
+                },
+            },
+        },
+    },
+    {
+        "type": "function",
+        "function": {
+            "name": "get_forecast",
+            "description": "Get the weather forecast for a given location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "city": {
+                        "type": "string",
+                        "description":
+                        "The city to get the forecast for, e.g. 'Vienna'",
+                        "default": "Vienna",
+                    },
+                    "country": {
+                        "type":
+                        "string",
+                        "description":
+                        "The country that the city is in, e.g. 'Austria'",
+                    },
+                    "days": {
+                        "type":
+                        "integer",
+                        "description":
+                        "Number of days to get the forecast for (1-7)",
+                    },
+                    "unit": {
+                        "type": "string",
+                        "description": "The unit to fetch the temperature in",
+                        "enum": ["celsius", "fahrenheit"],
+                    },
+                },
+                "required": ["country", "days", "unit"],
+            },
+        },
+    },
+]
+
+messages = [
+    {
+        "role": "user",
+        "content": "Hi! How are you doing today?"
+    },
+    {
+        "role": "assistant",
+        "content": "I'm doing well! How can I help you?"
+    },
+    {
+        "role":
+        "user",
+        "content":
+        "Can you tell me what the current weather is in Berlin and the "\
+        "forecast for the next 5 days, in fahrenheit?",
+    },
+]
+
+
 @pytest.fixture(scope="module")
 def server():  # noqa: F811
@@ -27,6 +148,8 @@ def server():  # noqa: F811
         "hermes",
         "--reasoning-parser",
         "qwen3",
+        "--gpu-memory-utilization",
+        "0.4"
     ]

     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@@ -54,129 +177,6 @@ async def client(server):
 async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str,
                                  stream: bool, tool_choice: Union[str, dict],
                                  enable_thinking: bool):
-    tools = [
-        {
-            "type": "function",
-            "function": {
-                "name": "get_current_weather",
-                "description": "Get the current weather in a given location",
-                "parameters": {
-                    "type": "object",
-                    "properties": {
-                        "city": {
-                            "type": "string",
-                            "description":
-                            "The city to find the weather for, e.g. 'Vienna'",
-                            "default": "Vienna",
-                        },
-                        "country": {
-                            "type":
-                            "string",
-                            "description":
-                            "The country that the city is in, e.g. 'Austria'",
-                        },
-                        "unit": {
-                            "type": "string",
-                            "description":
-                            "The unit to fetch the temperature in",
-                            "enum": ["celsius", "fahrenheit"],
-                        },
-                        "options": {
-                            "$ref": "#/$defs/WeatherOptions",
-                            "description":
-                            "Optional parameters for weather query",
-                        },
-                    },
-                    "required": ["country", "unit"],
-                    "$defs": {
-                        "WeatherOptions": {
-                            "title": "WeatherOptions",
-                            "type": "object",
-                            "additionalProperties": False,
-                            "properties": {
-                                "unit": {
-                                    "type": "string",
-                                    "enum": ["celsius", "fahrenheit"],
-                                    "default": "celsius",
-                                    "description": "Temperature unit",
-                                    "title": "Temperature Unit",
-                                },
-                                "include_forecast": {
-                                    "type": "boolean",
-                                    "default": False,
-                                    "description":
-                                    "Whether to include a 24-hour forecast",
-                                    "title": "Include Forecast",
-                                },
-                                "language": {
-                                    "type": "string",
-                                    "default": "zh-CN",
-                                    "description": "Language of the response",
-                                    "title": "Language",
-                                    "enum": ["zh-CN", "en-US", "ja-JP"],
-                                },
-                            },
-                        },
-                    },
-                },
-            },
-        },
-        {
-            "type": "function",
-            "function": {
-                "name": "get_forecast",
-                "description": "Get the weather forecast for a given location",
-                "parameters": {
-                    "type": "object",
-                    "properties": {
-                        "city": {
-                            "type": "string",
-                            "description":
-                            "The city to get the forecast for, e.g. 'Vienna'",
-                            "default": "Vienna",
-                        },
-                        "country": {
-                            "type":
-                            "string",
-                            "description":
-                            "The country that the city is in, e.g. 'Austria'",
-                        },
-                        "days": {
-                            "type":
-                            "integer",
-                            "description":
-                            "Number of days to get the forecast for (1-7)",
-                        },
-                        "unit": {
-                            "type": "string",
-                            "description":
-                            "The unit to fetch the temperature in",
-                            "enum": ["celsius", "fahrenheit"],
-                        },
-                    },
-                    "required": ["country", "days", "unit"],
-                },
-            },
-        },
-    ]
-
-    messages = [
-        {
-            "role": "user",
-            "content": "Hi! How are you doing today?"
-        },
-        {
-            "role": "assistant",
-            "content": "I'm doing well! How can I help you?"
-        },
-        {
-            "role":
-            "user",
-            "content":
-            "Can you tell me what the current weather is in Berlin and the "\
-            "forecast for the next 5 days, in fahrenheit?",
-        },
-    ]
-
     if not stream:
         # Non-streaming test
         chat_completion = await client.chat.completions.create(
@@ -216,3 +216,71 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str,
             output.extend(chunk.choices[0].delta.tool_calls)

     assert len(output) > 0
+
+
+@pytest.fixture(scope="module")
+def k2_server():  # noqa: F811
+    args = [
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "half",
+        "--enable-auto-tool-choice",
+        "--guided-decoding-backend",
+        "xgrammar",
+        "--tool-call-parser",
+        "hermes",
+        "--reasoning-parser",
+        "qwen3",
+        "--gpu-memory-utilization",
+        "0.4",
+    ]
+    # hack to test kimi_k2 tool use tool_id format.
+    # avoid error in is_deepseek_mla check by setting kv_lora_rank=null
+    with RemoteOpenAIServer(MODEL_NAME,
+                            args,
+                            override_hf_configs={
+                                "model_type": 'kimi_k2',
+                                'kv_lora_rank': None
+                            }) as remote_server:
+        yield remote_server
+
+
+@pytest_asyncio.fixture
+async def k2_client(k2_server):
+    async with k2_server.get_async_client() as async_client:
+        yield async_client
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.parametrize("stream", [True, False])
+@pytest.mark.parametrize("tool_choice", ["required"])
+async def test_tool_id_kimi_k2(k2_client: openai.AsyncOpenAI, model_name: str,
+                               stream: bool, tool_choice: str):
+
+    if not stream:
+        # Non-streaming test
+        chat_completion = await k2_client.chat.completions.create(
+            messages=messages,
+            model=model_name,
+            tools=tools,
+            tool_choice=tool_choice)
+        assert chat_completion.choices[0].message.tool_calls is not None
+        assert len(chat_completion.choices[0].message.tool_calls) > 0
+        assert chat_completion.choices[0].message.tool_calls[
+            0].id == 'functions.get_current_weather:0'
+    else:
+        # Streaming test
+        output_stream = await k2_client.chat.completions.create(
+            messages=messages,
+            model=model_name,
+            tools=tools,
+            tool_choice=tool_choice,
+            stream=True)
+
+        output = []
+        async for chunk in output_stream:
+            if chunk.choices and chunk.choices[0].delta.tool_calls:
+                output.extend(chunk.choices[0].delta.tool_calls)
+        for o in output:
+            assert o.id is None or o.id == 'functions.get_current_weather:0'
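A note on the streaming assertion above: `o.id is None` is tolerated because, in OpenAI-style streaming, only the first delta of a tool call carries the id and function name; later deltas stream argument fragments with no id. Roughly (the chunk dicts below are assumed shapes for illustration, not captured server output):

deltas = [
    {"id": "functions.get_current_weather:0", "arguments": ""},
    {"id": None, "arguments": '{"city": "Berlin"'},
    {"id": None, "arguments": ', "unit": "fahrenheit"}'},
]
ids = [d["id"] for d in deltas if d["id"] is not None]
assert ids == ["functions.get_current_weather:0"]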
@@ -5,6 +5,7 @@ import asyncio
 import copy
 import functools
 import importlib
+import json
 import os
 import signal
 import subprocess
@@ -101,7 +102,8 @@ class RemoteOpenAIServer:
                  env_dict: Optional[dict[str, str]] = None,
                  seed: Optional[int] = 0,
                  auto_port: bool = True,
-                 max_wait_seconds: Optional[float] = None) -> None:
+                 max_wait_seconds: Optional[float] = None,
+                 override_hf_configs: Optional[dict[str, Any]] = None) -> None:
         if auto_port:
             if "-p" in vllm_serve_args or "--port" in vllm_serve_args:
                 raise ValueError("You have manually specified the port "
@@ -120,6 +122,12 @@ class RemoteOpenAIServer:

         vllm_serve_args = vllm_serve_args + ["--seed", str(seed)]

+        if override_hf_configs is not None:
+            vllm_serve_args = vllm_serve_args + [
+                "--hf-overrides",
+                json.dumps(override_hf_configs)
+            ]
+
         parser = FlexibleArgumentParser(
             description="vLLM's remote OpenAI server.")
         subparsers = parser.add_subparsers(required=False, dest="subparser")
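For reference, the override dict from the `k2_server` fixture serializes into CLI arguments as below; Python's `None` becomes JSON `null`, which is exactly the "kv_lora_rank=null" trick the fixture comment mentions:

import json

override_hf_configs = {"model_type": "kimi_k2", "kv_lora_rank": None}
extra_args = ["--hf-overrides", json.dumps(override_hf_configs)]
print(extra_args)
# ['--hf-overrides', '{"model_type": "kimi_k2", "kv_lora_rank": null}']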
@@ -1345,5 +1345,18 @@ def apply_mistral_chat_template(
                          "template")
         raise ValueError(str(e)) from e


-def random_tool_call_id() -> str:
-    return f"chatcmpl-tool-{random_uuid()}"
+def get_history_tool_calls_cnt(conversation: list[ConversationMessage]):
+    idx = 0
+    for msg in conversation:
+        if msg['role'] == 'assistant':
+            tool_calls = msg.get('tool_calls')
+            idx += len(list(tool_calls)) if tool_calls is not None else 0  # noqa
+    return idx
+
+
+def make_tool_call_id(id_type: str = 'random', func_name=None, idx=None):
+    if id_type == 'kimi_k2':
+        return f'functions.{func_name}:{idx}'
+    else:
+        # by default return random
+        return f"chatcmpl-tool-{random_uuid()}"
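A small usage sketch of the two helpers added above (the conversation entries are abbreviated dicts; real `ConversationMessage` objects carry more fields):

conversation = [
    {"role": "user", "content": "What's the weather in Berlin?"},
    {"role": "assistant",
     "tool_calls": [{"id": "functions.get_current_weather:0"}]},
    {"role": "tool", "content": "20 degrees and sunny"},
]

idx = get_history_tool_calls_cnt(conversation)  # one prior tool call -> 1
next_id = make_tool_call_id(id_type="kimi_k2",
                            func_name="get_forecast",
                            idx=idx)
assert next_id == "functions.get_forecast:1"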
@@ -38,7 +38,7 @@ from typing_extensions import TypeAlias

 from vllm import envs
 from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
-                                         random_tool_call_id)
+                                         make_tool_call_id)
 from vllm.entrypoints.score_utils import (ScoreContentPartParam,
                                           ScoreMultiModalParam)
 from vllm.logger import init_logger
@@ -1634,7 +1634,7 @@ class FunctionCall(OpenAIBaseModel):


 class ToolCall(OpenAIBaseModel):
-    id: str = Field(default_factory=random_tool_call_id)
+    id: str = Field(default_factory=make_tool_call_id)
     type: Literal["function"] = "function"
     function: FunctionCall
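This swap is behavior-preserving for the default path: every parameter of `make_tool_call_id` has a default and `id_type` falls back to 'random', so Pydantic can still invoke it as a zero-argument `default_factory` and get the old `chatcmpl-tool-...` form.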
@@ -19,7 +19,8 @@ from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
                                          ConversationMessage,
-                                         random_tool_call_id)
+                                         get_history_tool_calls_cnt,
+                                         make_tool_call_id)
 from vllm.entrypoints.harmony_utils import (
     get_developer_message, get_stop_tokens_for_assistant_actions,
     get_streamable_parser_for_assistant, get_system_message, parse_chat_input,
@@ -133,6 +134,10 @@ class OpenAIServingChat(OpenAIServing):
                 source = "model" if source == "auto" else source
                 logger.info("Using default chat sampling params from %s: %s",
                             source, self.default_sampling_params)
+        if self.model_config.hf_config.model_type == 'kimi_k2':
+            self.tool_call_id_type = 'kimi_k2'
+        else:
+            self.tool_call_id_type = 'random'

         self.use_harmony = model_config.hf_config.model_type == "gpt_oss"
         if self.use_harmony:
@@ -379,6 +384,7 @@ class OpenAIServingChat(OpenAIServing):
         current_text: Optional[str],
         delta_text: str,
         function_name_returned: bool,
+        tool_call_idx: Optional[int] = None
     ) -> tuple[Optional[DeltaMessage], bool]:
         if current_text is None or current_text == "":
             # if the current text is empty, we cannot parse it
@@ -424,8 +430,12 @@ class OpenAIServingChat(OpenAIServing):
             current_tool_call = obj[-2]

             function_name_returned = True
+            tool_call_id = make_tool_call_id(
+                id_type=self.tool_call_id_type,
+                func_name=current_tool_call["name"],
+                idx=tool_call_idx)
             delta_message = DeltaMessage(tool_calls=[
-                DeltaToolCall(id=random_tool_call_id(),
+                DeltaToolCall(id=tool_call_id,
                               function=DeltaFunctionCall(
                                   name=current_tool_call["name"],
                                   arguments=arguments),
@@ -491,6 +501,10 @@ class OpenAIServingChat(OpenAIServing):

         all_previous_token_ids: Optional[list[list[int]]]
         function_name_returned = [False] * num_choices
+        if self.tool_call_id_type == 'kimi_k2':
+            history_tool_call_cnt = get_history_tool_calls_cnt(conversation)
+        else:
+            history_tool_call_cnt = 0

         # Always track previous_texts for comprehensive output logging
         previous_texts = [""] * num_choices
@@ -673,7 +687,6 @@ class OpenAIServingChat(OpenAIServing):
                 previous_text = previous_texts[i]
                 previous_token_ids = all_previous_token_ids[i]
                 current_text = previous_text + delta_text
-
                 # avoid the None + list error.
                 if previous_token_ids:
                     current_token_ids = previous_token_ids + as_list(
@@ -733,7 +746,7 @@ class OpenAIServingChat(OpenAIServing):
                                 index=i)
                     else:
                         delta_tool_call = DeltaToolCall(
-                            id=random_tool_call_id(),
+                            id=make_tool_call_id(),
                             type="function",
                             function=DeltaFunctionCall(
                                 name=tool_choice_function_name,
@@ -764,7 +777,11 @@ class OpenAIServingChat(OpenAIServing):
                             previous_text=previous_text,
                             current_text=content,
                             delta_text=delta_text,
-                            function_name_returned=fn_name_returned))
+                            function_name_returned=fn_name_returned,
+                            tool_call_idx=history_tool_call_cnt))
+                    if (delta_message and delta_message.tool_calls and
+                            delta_message.tool_calls[0].id is not None):
+                        history_tool_call_cnt += 1

                 # update the previous values for the next iteration
                 previous_texts[i] = current_text
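Design note on the increment above: `history_tool_call_cnt` only advances when the emitted delta actually carries an id. Since just the first chunk of each streamed tool call has a non-None id, every tool call consumes exactly one index, keeping streamed ids consistent with the non-streaming path.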
@@ -1089,6 +1106,10 @@ class OpenAIServingChat(OpenAIServing):
         assert final_res is not None

         choices: list[ChatCompletionResponseChoice] = []
+        if self.tool_call_id_type == 'kimi_k2':
+            history_tool_call_cnt = get_history_tool_calls_cnt(conversation)
+        else:
+            history_tool_call_cnt = 0

         role = self.get_chat_request_role(request)
         for output in final_res.outputs:
||||||
@ -1194,17 +1215,26 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
assert content is not None
|
assert content is not None
|
||||||
tool_calls = TypeAdapter(
|
tool_calls = TypeAdapter(
|
||||||
list[FunctionDefinition]).validate_json(content)
|
list[FunctionDefinition]).validate_json(content)
|
||||||
|
tool_call_ids = []
|
||||||
|
for tool_call in tool_calls:
|
||||||
|
tool_call_ids.append(
|
||||||
|
make_tool_call_id(id_type=self.tool_call_id_type,
|
||||||
|
func_name=tool_call.name,
|
||||||
|
idx=history_tool_call_cnt))
|
||||||
|
history_tool_call_cnt += 1
|
||||||
message = ChatMessage(
|
message = ChatMessage(
|
||||||
role=role,
|
role=role,
|
||||||
content="",
|
content="",
|
||||||
reasoning_content=reasoning_content,
|
|
||||||
tool_calls=[
|
tool_calls=[
|
||||||
tool_call_class(function=FunctionCall(
|
tool_call_class(id=tool_call_ids[i],
|
||||||
name=tool_call.name,
|
function=FunctionCall(
|
||||||
arguments=json.dumps(tool_call.parameters,
|
name=tool_call.name,
|
||||||
ensure_ascii=False)))
|
arguments=json.dumps(
|
||||||
for tool_call in tool_calls
|
tool_call.parameters,
|
||||||
])
|
ensure_ascii=False)))
|
||||||
|
for i, tool_call in enumerate(tool_calls)
|
||||||
|
],
|
||||||
|
reasoning_content=reasoning_content)
|
||||||
|
|
||||||
# if the request doesn't use tool choice
|
# if the request doesn't use tool choice
|
||||||
# OR specifies to not use a tool
|
# OR specifies to not use a tool
|
||||||
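To make the required-tool-choice path concrete, here is the id-assignment loop re-run in isolation (a sketch with hypothetical inputs that mirrors the test's expected ids):

def assign_ids(func_names, history_cnt=0, id_type="kimi_k2"):
    # one sequential id per validated tool call, as in the hunk above
    ids = []
    for name in func_names:
        if id_type == "kimi_k2":
            ids.append(f"functions.{name}:{history_cnt}")
        else:
            ids.append("chatcmpl-tool-<random>")
        history_cnt += 1
    return ids

print(assign_ids(["get_current_weather", "get_forecast"]))
# ['functions.get_current_weather:0', 'functions.get_forecast:1']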
@@ -1248,7 +1278,6 @@ class OpenAIServingChat(OpenAIServing):
                 if (tool_call_info.content
                         and len(tool_call_info.content) > 0):
                     ret_content = tool_call_info.content
-
             message = ChatMessage(role=role,
                                   reasoning_content=reasoning_content,
                                   content=ret_content)
@@ -1327,12 +1356,11 @@ class OpenAIServingChat(OpenAIServing):
         elif choice.message.tool_calls:
             # For tool calls, log the function name and arguments
             tool_call_descriptions = []
-            for tool_call in choice.message.tool_calls:
-                if hasattr(tool_call.function, "name") and hasattr(
-                        tool_call.function, "arguments"):
+            for tc in choice.message.tool_calls:
+                if hasattr(tc.function, "name") and hasattr(
+                        tc.function, "arguments"):
                     tool_call_descriptions.append(
-                        f"{tool_call.function.name}({tool_call.function.arguments})"
-                    )
+                        f"{tc.function.name}({tc.function.arguments})")
             tool_calls_str = ", ".join(tool_call_descriptions)
             output_text = f"[tool_calls: {tool_calls_str}]"
@@ -6,7 +6,7 @@ from typing import Union

 import regex as re

-from vllm.entrypoints.chat_utils import random_tool_call_id
+from vllm.entrypoints.chat_utils import make_tool_call_id
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaFunctionCall, DeltaMessage,
                                               DeltaToolCall,

@@ -267,7 +267,7 @@ class DeepSeekV3ToolParser(ToolParser):
                         DeltaToolCall(
                             index=self.current_tool_id,
                             type="function",
-                            id=random_tool_call_id(),
+                            id=make_tool_call_id(),
                             function=DeltaFunctionCall(
                                 name=function_name).model_dump(
                                     exclude_none=True),
@@ -10,7 +10,7 @@ import partial_json_parser
 import regex as re
 from partial_json_parser.core.options import Allow

-from vllm.entrypoints.chat_utils import random_tool_call_id
+from vllm.entrypoints.chat_utils import make_tool_call_id
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaFunctionCall, DeltaMessage,
                                               DeltaToolCall,

@@ -203,7 +203,7 @@ class Granite20bFCToolParser(ToolParser):
             delta = DeltaMessage(tool_calls=[
                 DeltaToolCall(index=self.current_tool_id,
                               type="function",
-                              id=random_tool_call_id(),
+                              id=make_tool_call_id(),
                               function=DeltaFunctionCall(
                                   name=function_name).model_dump(
                                       exclude_none=True))
@@ -8,7 +8,7 @@ from typing import Union
 import partial_json_parser
 from partial_json_parser.core.options import Allow

-from vllm.entrypoints.chat_utils import random_tool_call_id
+from vllm.entrypoints.chat_utils import make_tool_call_id
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaFunctionCall, DeltaMessage,
                                               DeltaToolCall,

@@ -185,7 +185,7 @@ class GraniteToolParser(ToolParser):
             delta = DeltaMessage(tool_calls=[
                 DeltaToolCall(index=self.current_tool_id,
                               type="function",
-                              id=random_tool_call_id(),
+                              id=make_tool_call_id(),
                               function=DeltaFunctionCall(
                                   name=function_name).model_dump(
                                       exclude_none=True))
@@ -9,7 +9,7 @@ import partial_json_parser
 import regex as re
 from partial_json_parser.core.options import Allow

-from vllm.entrypoints.chat_utils import random_tool_call_id
+from vllm.entrypoints.chat_utils import make_tool_call_id
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaFunctionCall, DeltaMessage,
                                               DeltaToolCall,

@@ -307,7 +307,7 @@ class Hermes2ProToolParser(ToolParser):
                 return DeltaMessage(tool_calls=[
                     DeltaToolCall(index=self.current_tool_id,
                                   type="function",
-                                  id=random_tool_call_id(),
+                                  id=make_tool_call_id(),
                                   function=DeltaFunctionCall(
                                       name=function_name).model_dump(
                                           exclude_none=True))
@@ -8,7 +8,7 @@ from typing import Union
 import partial_json_parser
 from partial_json_parser.core.options import Allow

-from vllm.entrypoints.chat_utils import random_tool_call_id
+from vllm.entrypoints.chat_utils import make_tool_call_id
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaFunctionCall, DeltaMessage,
                                               DeltaToolCall,

@@ -107,7 +107,7 @@ class Internlm2ToolParser(ToolParser):
             delta = DeltaMessage(tool_calls=[
                 DeltaToolCall(index=self.current_tool_id,
                               type="function",
-                              id=random_tool_call_id(),
+                              id=make_tool_call_id(),
                               function=DeltaFunctionCall(
                                   name=function_name).model_dump(
                                       exclude_none=True))
@@ -9,7 +9,7 @@ import partial_json_parser
 import regex as re
 from partial_json_parser.core.options import Allow

-from vllm.entrypoints.chat_utils import random_tool_call_id
+from vllm.entrypoints.chat_utils import make_tool_call_id
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaFunctionCall, DeltaMessage,
                                               DeltaToolCall,

@@ -222,7 +222,7 @@ class JambaToolParser(ToolParser):
             delta = DeltaMessage(tool_calls=[
                 DeltaToolCall(index=self.current_tool_id,
                               type="function",
-                              id=random_tool_call_id(),
+                              id=make_tool_call_id(),
                               function=DeltaFunctionCall(
                                   name=function_name).model_dump(
                                       exclude_none=True))
@@ -10,7 +10,7 @@ import regex as re
 from partial_json_parser.core.options import Allow
 from transformers import PreTrainedTokenizerBase

-from vllm.entrypoints.chat_utils import random_tool_call_id
+from vllm.entrypoints.chat_utils import make_tool_call_id
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaFunctionCall, DeltaMessage,
                                               DeltaToolCall,

@@ -213,7 +213,7 @@ class Llama3JsonToolParser(ToolParser):
             delta = DeltaMessage(tool_calls=[
                 DeltaToolCall(index=self.current_tool_id,
                               type="function",
-                              id=random_tool_call_id(),
+                              id=make_tool_call_id(),
                               function=DeltaFunctionCall(
                                   name=function_name).model_dump(
                                       exclude_none=True))
@@ -7,7 +7,7 @@ from typing import Any, Optional, Union

 import regex as re

-from vllm.entrypoints.chat_utils import random_tool_call_id
+from vllm.entrypoints.chat_utils import make_tool_call_id
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaFunctionCall, DeltaMessage,
                                               DeltaToolCall,

@@ -394,7 +394,7 @@ class MinimaxToolParser(ToolParser):
                 sent_tools.append({
                     "sent_name": False,
                     "sent_arguments": "",
-                    "id": random_tool_call_id(),
+                    "id": make_tool_call_id(),
                 })

             while len(tool_ids) < tool_count:
@@ -8,7 +8,7 @@ from typing import Any, Optional
 import regex as re
 from transformers import PreTrainedTokenizerBase

-from vllm.entrypoints.chat_utils import random_tool_call_id
+from vllm.entrypoints.chat_utils import make_tool_call_id
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaMessage,
                                               ExtractedToolCallInformation,

@@ -74,7 +74,7 @@ class Phi4MiniJsonToolParser(ToolParser):

         tool_calls: list[ToolCall] = [
             ToolCall(
-                id=random_tool_call_id(),
+                id=make_tool_call_id(),
                 type="function",
                 function=FunctionCall(
                     name=raw_function_call["name"],
@@ -7,7 +7,7 @@ from typing import Any, Optional, Union

 import regex as re

-from vllm.entrypoints.chat_utils import random_tool_call_id
+from vllm.entrypoints.chat_utils import make_tool_call_id
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaFunctionCall, DeltaMessage,
                                               DeltaToolCall,

@@ -226,7 +226,7 @@ class xLAMToolParser(ToolParser):
                 function_name = name_match.group(1)

                 # The test expects us to send just the name first
-                tool_id = random_tool_call_id()
+                tool_id = make_tool_call_id()
                 delta = DeltaMessage(tool_calls=[
                     DeltaToolCall(
                         index=0,