mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 01:24:54 +08:00
Initialize the delta tool call fields explicitly (#17340)
Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Co-authored-by: igmainc <igmainc@icloud.com>
This commit is contained in:
parent
7ea6cb28b2
commit
05a4324f8e
@ -32,7 +32,7 @@ class StreamingToolReconstructor:
|
||||
assert len(delta.tool_calls) < 2, (
|
||||
"Streaming should include only one tool call per update.")
|
||||
for call_delta in delta.tool_calls:
|
||||
assert call_delta.type == "function", (
|
||||
assert call_delta.type is None or call_delta.type == "function", (
|
||||
"Streaming tool calls should only emit function calls. Got "
|
||||
f"{call_delta.type}")
|
||||
current_tool_call = self.tool_calls[
|
||||
|
||||
@ -44,6 +44,7 @@ from vllm.transformers_utils.chat_templates import (
|
||||
# yapf: enable
|
||||
from vllm.transformers_utils.processor import cached_get_processor
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -1272,3 +1273,6 @@ def apply_mistral_chat_template(
|
||||
"An error occurred in `mistral_common` while applying chat "
|
||||
"template")
|
||||
raise ValueError from e
|
||||
|
||||
def random_tool_call_id() -> str:
|
||||
return f"chatcmpl-tool-{random_uuid()}"
|
||||
@ -15,7 +15,8 @@ from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter,
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from vllm import envs
|
||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
|
||||
random_tool_call_id)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
|
||||
@ -1339,7 +1340,7 @@ class FunctionCall(OpenAIBaseModel):
|
||||
|
||||
|
||||
class ToolCall(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
|
||||
id: str = Field(default_factory=random_tool_call_id)
|
||||
type: Literal["function"] = "function"
|
||||
function: FunctionCall
|
||||
|
||||
@ -1351,8 +1352,8 @@ class DeltaFunctionCall(BaseModel):
|
||||
|
||||
# a tool call delta where everything is optional
|
||||
class DeltaToolCall(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
|
||||
type: Literal["function"] = "function"
|
||||
id: Optional[str] = None
|
||||
type: Optional[Literal["function"]] = None
|
||||
index: int
|
||||
function: Optional[DeltaFunctionCall] = None
|
||||
|
||||
|
||||
@ -16,7 +16,8 @@ from pydantic import TypeAdapter
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
|
||||
ConversationMessage)
|
||||
ConversationMessage,
|
||||
random_tool_call_id)
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionLogProb, ChatCompletionLogProbs,
|
||||
@ -363,9 +364,10 @@ class OpenAIServingChat(OpenAIServing):
|
||||
|
||||
function_name_returned = True
|
||||
delta_message = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(function=DeltaFunctionCall(
|
||||
name=current_tool_call["name"],
|
||||
arguments=arguments),
|
||||
DeltaToolCall(id=random_tool_call_id(),
|
||||
function=DeltaFunctionCall(
|
||||
name=current_tool_call["name"],
|
||||
arguments=arguments),
|
||||
index=len(obj) - 1,
|
||||
type="function")
|
||||
])
|
||||
@ -382,8 +384,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
# instead of name every time
|
||||
name=None,
|
||||
arguments=delta_text),
|
||||
index=len(obj) - 1,
|
||||
type="function")
|
||||
index=len(obj) - 1)
|
||||
])
|
||||
else:
|
||||
delta_message = None
|
||||
@ -422,7 +423,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
and self._should_stream_with_auto_tool_parsing(request))
|
||||
|
||||
all_previous_token_ids: Optional[list[list[int]]]
|
||||
function_name_returned: Optional[list[bool]] = None
|
||||
function_name_returned = [False] * num_choices
|
||||
|
||||
# Only one of these will be used, thus previous_texts and
|
||||
# all_previous_token_ids will not be used twice in the same iteration.
|
||||
@ -435,7 +436,6 @@ class OpenAIServingChat(OpenAIServing):
|
||||
reasoning_end_arr = [False] * num_choices
|
||||
elif request.tool_choice == "required":
|
||||
previous_texts = [""] * num_choices
|
||||
function_name_returned = [False] * num_choices
|
||||
all_previous_token_ids = None
|
||||
else:
|
||||
previous_texts, all_previous_token_ids = None, None
|
||||
@ -623,16 +623,27 @@ class OpenAIServingChat(OpenAIServing):
|
||||
delta_text = previous_text + delta_text
|
||||
current_text = ""
|
||||
|
||||
if function_name_returned[i]:
|
||||
delta_tool_call = DeltaToolCall(
|
||||
function=DeltaFunctionCall(
|
||||
arguments=delta_text),
|
||||
index=i)
|
||||
else:
|
||||
delta_tool_call = DeltaToolCall(
|
||||
id=random_tool_call_id(),
|
||||
type="function",
|
||||
function=DeltaFunctionCall(
|
||||
name=tool_choice_function_name,
|
||||
arguments=delta_text),
|
||||
index=i)
|
||||
function_name_returned[i] = True
|
||||
|
||||
delta_message = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(function=DeltaFunctionCall(
|
||||
name=tool_choice_function_name,
|
||||
arguments=delta_text),
|
||||
index=i)
|
||||
delta_tool_call,
|
||||
])
|
||||
|
||||
elif request.tool_choice == "required":
|
||||
assert previous_texts is not None
|
||||
assert function_name_returned is not None
|
||||
previous_text = previous_texts[i]
|
||||
current_text = previous_text + delta_text
|
||||
fn_name_returned = function_name_returned[i]
|
||||
@ -835,7 +846,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
total_tokens=num_prompt_tokens + completion_tokens,
|
||||
)
|
||||
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
data = chunk.model_dump_json(exclude_none=True)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
# once the final token is handled, if stream_options.include_usage
|
||||
|
||||
@ -9,6 +9,7 @@ from typing import Union
|
||||
import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.chat_utils import random_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
@ -22,7 +23,6 @@ from vllm.entrypoints.openai.tool_parsers.utils import (consume_space,
|
||||
partial_json_loads)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -200,7 +200,7 @@ class Granite20bFCToolParser(ToolParser):
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
type="function",
|
||||
id=f"chatcmpl-tool-{random_uuid()}",
|
||||
id=random_tool_call_id(),
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name).model_dump(
|
||||
exclude_none=True))
|
||||
|
||||
@ -7,6 +7,7 @@ from typing import Union
|
||||
import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.chat_utils import random_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
@ -20,7 +21,6 @@ from vllm.entrypoints.openai.tool_parsers.utils import (consume_space,
|
||||
partial_json_loads)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -182,7 +182,7 @@ class GraniteToolParser(ToolParser):
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
type="function",
|
||||
id=f"chatcmpl-tool-{random_uuid()}",
|
||||
id=random_tool_call_id(),
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name).model_dump(
|
||||
exclude_none=True))
|
||||
|
||||
@ -8,6 +8,7 @@ from typing import Union
|
||||
import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.chat_utils import random_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
@ -17,7 +18,6 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -259,7 +259,7 @@ class Hermes2ProToolParser(ToolParser):
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
type="function",
|
||||
id=f"chatcmpl-tool-{random_uuid()}",
|
||||
id=random_tool_call_id(),
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name).model_dump(
|
||||
exclude_none=True))
|
||||
|
||||
@ -7,6 +7,7 @@ from typing import Union
|
||||
import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.chat_utils import random_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
@ -18,7 +19,6 @@ from vllm.entrypoints.openai.tool_parsers.utils import (
|
||||
extract_intermediate_diff)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -106,7 +106,7 @@ class Internlm2ToolParser(ToolParser):
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
type="function",
|
||||
id=f"chatcmpl-tool-{random_uuid()}",
|
||||
id=random_tool_call_id(),
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name).model_dump(
|
||||
exclude_none=True))
|
||||
|
||||
@ -8,6 +8,7 @@ from typing import Union
|
||||
import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.chat_utils import random_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
@ -19,7 +20,6 @@ from vllm.entrypoints.openai.tool_parsers.utils import (
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.transformers_utils.tokenizers import MistralTokenizer
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -220,7 +220,7 @@ class JambaToolParser(ToolParser):
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
type="function",
|
||||
id=f"chatcmpl-tool-{random_uuid()}",
|
||||
id=random_tool_call_id(),
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name).model_dump(
|
||||
exclude_none=True))
|
||||
|
||||
@ -10,6 +10,7 @@ import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from vllm.entrypoints.chat_utils import random_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
@ -21,7 +22,6 @@ from vllm.entrypoints.openai.tool_parsers.utils import (find_common_prefix,
|
||||
is_complete_json,
|
||||
partial_json_loads)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -208,7 +208,7 @@ class Llama3JsonToolParser(ToolParser):
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
type="function",
|
||||
id=f"chatcmpl-tool-{random_uuid()}",
|
||||
id=random_tool_call_id(),
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name).model_dump(
|
||||
exclude_none=True))
|
||||
|
||||
@ -7,6 +7,7 @@ from typing import Any, Optional
|
||||
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from vllm.entrypoints.chat_utils import random_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaMessage,
|
||||
ExtractedToolCallInformation,
|
||||
@ -14,7 +15,6 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -73,7 +73,7 @@ class Phi4MiniJsonToolParser(ToolParser):
|
||||
|
||||
tool_calls: list[ToolCall] = [
|
||||
ToolCall(
|
||||
id=f"chatcmpl-tool-{random_uuid()}",
|
||||
id=random_tool_call_id(),
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=raw_function_call["name"],
|
||||
|
||||
@ -280,6 +280,7 @@ def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall,
|
||||
new_call_args = new_call_args[:-len(withheld_suffix)]
|
||||
if not previously_sent_args:
|
||||
return DeltaToolCall(id=new_call.id,
|
||||
type="function",
|
||||
index=index,
|
||||
function=DeltaFunctionCall(
|
||||
name=new_call.function.name,
|
||||
@ -288,5 +289,5 @@ def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall,
|
||||
|
||||
arg_diff = new_call_args[len(previously_sent_args):]
|
||||
return DeltaToolCall(
|
||||
id="", index=index, function=DeltaFunctionCall(
|
||||
id=None, index=index, function=DeltaFunctionCall(
|
||||
arguments=arg_diff)) if arg_diff else None
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user