Initialize the delta tool call fields explicitly (#17340)

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: igmainc <igmainc@icloud.com>
Maximilien de Bayser 2025-05-12 10:28:58 -03:00 committed by GitHub
parent 7ea6cb28b2
commit 05a4324f8e
12 changed files with 51 additions and 34 deletions

View File

@@ -32,7 +32,7 @@ class StreamingToolReconstructor:
         assert len(delta.tool_calls) < 2, (
             "Streaming should include only one tool call per update.")
         for call_delta in delta.tool_calls:
-            assert call_delta.type == "function", (
+            assert call_delta.type is None or call_delta.type == "function", (
                 "Streaming tool calls should only emit function calls. Got "
                 f"{call_delta.type}")
             current_tool_call = self.tool_calls[
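
Note: the relaxed assertion mirrors the streaming contract this commit introduces: only the first chunk of a tool call carries `id` and `type`; continuation chunks leave both as None. A minimal accumulator sketch of that contract (illustrative names, not code from this PR):

from typing import Optional

class ToolCallAccumulator:
    """Collects streamed tool-call deltas keyed by their `index` field."""

    def __init__(self) -> None:
        self.calls: dict[int, dict] = {}

    def add(self, index: int, call_id: Optional[str],
            call_type: Optional[str], name: Optional[str],
            arguments: Optional[str]) -> None:
        # `type` may now legitimately be None on continuation chunks.
        assert call_type is None or call_type == "function"
        entry = self.calls.setdefault(
            index, {"id": None, "name": None, "args": ""})
        if call_id is not None:  # only the first chunk carries the id
            entry["id"] = call_id
        if name is not None:
            entry["name"] = name
        if arguments is not None:
            entry["args"] += arguments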

View File

@@ -44,6 +44,7 @@ from vllm.transformers_utils.chat_templates import (
 # yapf: enable
 from vllm.transformers_utils.processor import cached_get_processor
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
+from vllm.utils import random_uuid

 logger = init_logger(__name__)
@@ -1272,3 +1273,6 @@ def apply_mistral_chat_template(
             "An error occurred in `mistral_common` while applying chat "
             "template")
         raise ValueError from e
+
+def random_tool_call_id() -> str:
+    return f"chatcmpl-tool-{random_uuid()}"

View File

@@ -15,7 +15,8 @@ from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter,
 from typing_extensions import TypeAlias

 from vllm import envs
-from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
+from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
+                                         random_tool_call_id)
 from vllm.logger import init_logger
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
@@ -1339,7 +1340,7 @@ class FunctionCall(OpenAIBaseModel):


 class ToolCall(OpenAIBaseModel):
-    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
+    id: str = Field(default_factory=random_tool_call_id)
     type: Literal["function"] = "function"
     function: FunctionCall
@@ -1351,8 +1352,8 @@ class DeltaFunctionCall(BaseModel):


 # a tool call delta where everything is optional
 class DeltaToolCall(OpenAIBaseModel):
-    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
-    type: Literal["function"] = "function"
+    id: Optional[str] = None
+    type: Optional[Literal["function"]] = None
     index: int
     function: Optional[DeltaFunctionCall] = None
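
With both fields optional and defaulting to None, a continuation delta no longer mints a fresh random id per chunk, which was the problem with the old `default_factory`. A small sketch of the new serialization behavior (stand-in model classes mirroring the fields above; pydantic v2 semantics):

from typing import Literal, Optional

from pydantic import BaseModel

class DeltaFunctionCallSketch(BaseModel):
    name: Optional[str] = None
    arguments: Optional[str] = None

class DeltaToolCallSketch(BaseModel):
    id: Optional[str] = None
    type: Optional[Literal["function"]] = None
    index: int
    function: Optional[DeltaFunctionCallSketch] = None

# First chunk: the caller sets id and type explicitly.
first = DeltaToolCallSketch(
    id="chatcmpl-tool-abc123",
    type="function",
    index=0,
    function=DeltaFunctionCallSketch(name="get_weather"))

# Continuation chunk: id and type stay None and drop out on the wire.
cont = DeltaToolCallSketch(
    index=0, function=DeltaFunctionCallSketch(arguments='{"ci'))
print(cont.model_dump_json(exclude_none=True))
# {"index":0,"function":{"arguments":"{\"ci"}}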

View File

@@ -16,7 +16,8 @@ from pydantic import TypeAdapter
 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
-                                         ConversationMessage)
+                                         ConversationMessage,
+                                         random_tool_call_id)
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (
     ChatCompletionLogProb, ChatCompletionLogProbs,
@@ -363,9 +364,10 @@ class OpenAIServingChat(OpenAIServing):
                function_name_returned = True

                delta_message = DeltaMessage(tool_calls=[
-                   DeltaToolCall(function=DeltaFunctionCall(
-                       name=current_tool_call["name"],
-                       arguments=arguments),
+                   DeltaToolCall(id=random_tool_call_id(),
+                                 function=DeltaFunctionCall(
+                                     name=current_tool_call["name"],
+                                     arguments=arguments),
                                  index=len(obj) - 1,
                                  type="function")
                ])
@@ -382,8 +384,7 @@ class OpenAIServingChat(OpenAIServing):
                        # instead of name every time
                        name=None,
                        arguments=delta_text),
-                                 index=len(obj) - 1,
-                                 type="function")
+                                 index=len(obj) - 1)
                ])
            else:
                delta_message = None
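
These two hunks give the streamed tool calls the standard OpenAI shape: the first delta of a call carries `id`, `type`, and the function name, while every later delta carries only `index` plus an `arguments` fragment. Roughly, the emitted deltas look like this (illustrative values, not captured output):

first_delta = {
    "index": 0,
    "id": "chatcmpl-tool-abc123",   # minted once via random_tool_call_id()
    "type": "function",
    "function": {"name": "get_weather", "arguments": '{"city": "'},
}
continuation_delta = {
    "index": 0,                     # id and type are omitted from here on
    "function": {"arguments": 'Paris"}'},
}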
@@ -422,7 +423,7 @@ class OpenAIServingChat(OpenAIServing):
             and self._should_stream_with_auto_tool_parsing(request))

         all_previous_token_ids: Optional[list[list[int]]]
-        function_name_returned: Optional[list[bool]] = None
+        function_name_returned = [False] * num_choices

         # Only one of these will be used, thus previous_texts and
         # all_previous_token_ids will not be used twice in the same iteration.
@@ -435,7 +436,6 @@ class OpenAIServingChat(OpenAIServing):
             reasoning_end_arr = [False] * num_choices
         elif request.tool_choice == "required":
             previous_texts = [""] * num_choices
-            function_name_returned = [False] * num_choices
             all_previous_token_ids = None
         else:
             previous_texts, all_previous_token_ids = None, None
@@ -623,16 +623,27 @@ class OpenAIServingChat(OpenAIServing):
                        delta_text = previous_text + delta_text
                        current_text = ""

+                    if function_name_returned[i]:
+                        delta_tool_call = DeltaToolCall(
+                            function=DeltaFunctionCall(
+                                arguments=delta_text),
+                            index=i)
+                    else:
+                        delta_tool_call = DeltaToolCall(
+                            id=random_tool_call_id(),
+                            type="function",
+                            function=DeltaFunctionCall(
+                                name=tool_choice_function_name,
+                                arguments=delta_text),
+                            index=i)
+                        function_name_returned[i] = True
+
                    delta_message = DeltaMessage(tool_calls=[
-                        DeltaToolCall(function=DeltaFunctionCall(
-                            name=tool_choice_function_name,
-                            arguments=delta_text),
-                                      index=i)
+                        delta_tool_call,
                    ])

                elif request.tool_choice == "required":
                    assert previous_texts is not None
-                    assert function_name_returned is not None
                    previous_text = previous_texts[i]
                    current_text = previous_text + delta_text
                    fn_name_returned = function_name_returned[i]
@@ -835,7 +846,7 @@ class OpenAIServingChat(OpenAIServing):
                    total_tokens=num_prompt_tokens + completion_tokens,
                )

-            data = chunk.model_dump_json(exclude_unset=True)
+            data = chunk.model_dump_json(exclude_none=True)
            yield f"data: {data}\n\n"

    # once the final token is handled, if stream_options.include_usage
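
The serialization switch complements the model change: `exclude_unset` keeps any field the caller explicitly passed, even when its value is None, while `exclude_none` drops every None regardless of origin, so the explicit `id=None`/`type=None` continuation deltas stay off the wire. A toy comparison (illustrative model, pydantic v2 semantics):

from typing import Optional

from pydantic import BaseModel

class Chunk(BaseModel):
    index: int
    id: Optional[str] = None
    note: str = "hello"

c = Chunk(index=0, id=None)  # id passed explicitly, note left at its default
print(c.model_dump_json(exclude_unset=True))  # {"index":0,"id":null}
print(c.model_dump_json(exclude_none=True))   # {"index":0,"note":"hello"}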

View File

@@ -9,6 +9,7 @@ from typing import Union
 import partial_json_parser
 from partial_json_parser.core.options import Allow

+from vllm.entrypoints.chat_utils import random_tool_call_id
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaFunctionCall, DeltaMessage,
                                               DeltaToolCall,
@@ -22,7 +23,6 @@ from vllm.entrypoints.openai.tool_parsers.utils import (consume_space,
                                                         partial_json_loads)
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.utils import random_uuid

 logger = init_logger(__name__)
@@ -200,7 +200,7 @@ class Granite20bFCToolParser(ToolParser):
                 delta = DeltaMessage(tool_calls=[
                     DeltaToolCall(index=self.current_tool_id,
                                   type="function",
-                                  id=f"chatcmpl-tool-{random_uuid()}",
+                                  id=random_tool_call_id(),
                                   function=DeltaFunctionCall(
                                       name=function_name).model_dump(
                                           exclude_none=True))
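
The remaining parser diffs below repeat this exact substitution, so each parser builds the header chunk of a call the same way. A condensed sketch of the shared pattern (`make_header_delta` is a hypothetical helper; `current_tool_id` and `function_name` come from parser state):

from vllm.entrypoints.chat_utils import random_tool_call_id
from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage,
                                              DeltaToolCall)

def make_header_delta(current_tool_id: int,
                      function_name: str) -> DeltaMessage:
    return DeltaMessage(tool_calls=[
        DeltaToolCall(index=current_tool_id,
                      type="function",           # set explicitly, first chunk only
                      id=random_tool_call_id(),  # minted once per tool call
                      function=DeltaFunctionCall(
                          name=function_name).model_dump(exclude_none=True))
    ])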

View File

@@ -7,6 +7,7 @@ from typing import Union
 import partial_json_parser
 from partial_json_parser.core.options import Allow

+from vllm.entrypoints.chat_utils import random_tool_call_id
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaFunctionCall, DeltaMessage,
                                               DeltaToolCall,
@@ -20,7 +21,6 @@ from vllm.entrypoints.openai.tool_parsers.utils import (consume_space,
                                                         partial_json_loads)
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.utils import random_uuid

 logger = init_logger(__name__)
@@ -182,7 +182,7 @@ class GraniteToolParser(ToolParser):
                 delta = DeltaMessage(tool_calls=[
                     DeltaToolCall(index=self.current_tool_id,
                                   type="function",
-                                  id=f"chatcmpl-tool-{random_uuid()}",
+                                  id=random_tool_call_id(),
                                   function=DeltaFunctionCall(
                                       name=function_name).model_dump(
                                           exclude_none=True))

View File

@@ -8,6 +8,7 @@ from typing import Union
 import partial_json_parser
 from partial_json_parser.core.options import Allow

+from vllm.entrypoints.chat_utils import random_tool_call_id
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaFunctionCall, DeltaMessage,
                                               DeltaToolCall,
@@ -17,7 +18,6 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
     ToolParser, ToolParserManager)
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
-from vllm.utils import random_uuid

 logger = init_logger(__name__)
@@ -259,7 +259,7 @@ class Hermes2ProToolParser(ToolParser):
                 return DeltaMessage(tool_calls=[
                     DeltaToolCall(index=self.current_tool_id,
                                   type="function",
-                                  id=f"chatcmpl-tool-{random_uuid()}",
+                                  id=random_tool_call_id(),
                                   function=DeltaFunctionCall(
                                       name=function_name).model_dump(
                                           exclude_none=True))

View File

@@ -7,6 +7,7 @@ from typing import Union
 import partial_json_parser
 from partial_json_parser.core.options import Allow

+from vllm.entrypoints.chat_utils import random_tool_call_id
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaFunctionCall, DeltaMessage,
                                               DeltaToolCall,
@@ -18,7 +19,6 @@ from vllm.entrypoints.openai.tool_parsers.utils import (
     extract_intermediate_diff)
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.utils import random_uuid

 logger = init_logger(__name__)
@@ -106,7 +106,7 @@ class Internlm2ToolParser(ToolParser):
                 delta = DeltaMessage(tool_calls=[
                     DeltaToolCall(index=self.current_tool_id,
                                   type="function",
-                                  id=f"chatcmpl-tool-{random_uuid()}",
+                                  id=random_tool_call_id(),
                                   function=DeltaFunctionCall(
                                       name=function_name).model_dump(
                                           exclude_none=True))

View File

@@ -8,6 +8,7 @@ from typing import Union
 import partial_json_parser
 from partial_json_parser.core.options import Allow

+from vllm.entrypoints.chat_utils import random_tool_call_id
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaFunctionCall, DeltaMessage,
                                               DeltaToolCall,
@@ -19,7 +20,6 @@ from vllm.entrypoints.openai.tool_parsers.utils import (
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizers import MistralTokenizer
-from vllm.utils import random_uuid

 logger = init_logger(__name__)
@@ -220,7 +220,7 @@ class JambaToolParser(ToolParser):
                 delta = DeltaMessage(tool_calls=[
                     DeltaToolCall(index=self.current_tool_id,
                                   type="function",
-                                  id=f"chatcmpl-tool-{random_uuid()}",
+                                  id=random_tool_call_id(),
                                   function=DeltaFunctionCall(
                                       name=function_name).model_dump(
                                           exclude_none=True))

View File

@@ -10,6 +10,7 @@ import partial_json_parser
 from partial_json_parser.core.options import Allow
 from transformers import PreTrainedTokenizerBase

+from vllm.entrypoints.chat_utils import random_tool_call_id
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaFunctionCall, DeltaMessage,
                                               DeltaToolCall,
@@ -21,7 +22,6 @@ from vllm.entrypoints.openai.tool_parsers.utils import (find_common_prefix,
                                                         is_complete_json,
                                                         partial_json_loads)
 from vllm.logger import init_logger
-from vllm.utils import random_uuid

 logger = init_logger(__name__)
@@ -208,7 +208,7 @@ class Llama3JsonToolParser(ToolParser):
                 delta = DeltaMessage(tool_calls=[
                     DeltaToolCall(index=self.current_tool_id,
                                   type="function",
-                                  id=f"chatcmpl-tool-{random_uuid()}",
+                                  id=random_tool_call_id(),
                                   function=DeltaFunctionCall(
                                       name=function_name).model_dump(
                                           exclude_none=True))

View File

@@ -7,6 +7,7 @@ from typing import Any, Optional
 from transformers import PreTrainedTokenizerBase

+from vllm.entrypoints.chat_utils import random_tool_call_id
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaMessage,
                                               ExtractedToolCallInformation,
@@ -14,7 +15,6 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
 from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
     ToolParser, ToolParserManager)
 from vllm.logger import init_logger
-from vllm.utils import random_uuid

 logger = init_logger(__name__)
@@ -73,7 +73,7 @@ class Phi4MiniJsonToolParser(ToolParser):
             tool_calls: list[ToolCall] = [
                 ToolCall(
-                    id=f"chatcmpl-tool-{random_uuid()}",
+                    id=random_tool_call_id(),
                     type="function",
                     function=FunctionCall(
                         name=raw_function_call["name"],

View File

@@ -280,6 +280,7 @@ def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall,
         new_call_args = new_call_args[:-len(withheld_suffix)]
     if not previously_sent_args:
         return DeltaToolCall(id=new_call.id,
+                             type="function",
                              index=index,
                              function=DeltaFunctionCall(
                                  name=new_call.function.name,
@@ -288,5 +289,5 @@ def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall,

     arg_diff = new_call_args[len(previously_sent_args):]
     return DeltaToolCall(
-        id="", index=index, function=DeltaFunctionCall(
+        id=None, index=index, function=DeltaFunctionCall(
             arguments=arg_diff)) if arg_diff else None
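
`_compute_tool_delta` streams only the not-yet-sent suffix of the JSON arguments; with `id=None` the continuation chunk now serializes without the bogus empty-string id it previously carried. A worked example of the suffix computation (values illustrative):

previously_sent_args = '{"city": "Pa'
new_call_args = '{"city": "Paris"}'

arg_diff = new_call_args[len(previously_sent_args):]
print(arg_diff)  # ris"}
# The resulting DeltaToolCall(id=None, index=..., function=DeltaFunctionCall(
# arguments='ris"}')) serializes with no id field at all under exclude_none,
# where the old code sent id="".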