mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-22 02:55:01 +08:00
1317 lines
52 KiB
Python
1317 lines
52 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
import ast
|
|
import json
|
|
from collections.abc import Sequence
|
|
from typing import Any
|
|
from xml.parsers.expat import ParserCreate
|
|
|
|
import regex as re
|
|
|
|
from vllm.entrypoints.chat_utils import make_tool_call_id
|
|
from vllm.entrypoints.openai.protocol import (
|
|
ChatCompletionRequest,
|
|
ChatCompletionToolsParam,
|
|
DeltaFunctionCall,
|
|
DeltaMessage,
|
|
DeltaToolCall,
|
|
ExtractedToolCallInformation,
|
|
FunctionCall,
|
|
ToolCall,
|
|
)
|
|
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
|
ToolParser,
|
|
)
|
|
from vllm.logger import init_logger
|
|
from vllm.tokenizers import TokenizerLike
|
|
|
|
# Module-level logger; the parser below uses it for best-effort parse warnings.
logger = init_logger(__name__)
|
|
|
|
|
|
class StreamingXMLToolCallParser:
|
|
"""
|
|
Simplified streaming XML tool call parser
|
|
Supports streaming input, parsing, and output
|
|
"""
|
|
|
|
def __init__(self):
|
|
self.reset_streaming_state()
|
|
|
|
# Tool configuration information
|
|
self.tools: list[ChatCompletionToolsParam] | None = None
|
|
self.tool_call_start_token: str = "<tool_call>"
|
|
self.tool_call_end_token: str = "</tool_call>"
|
|
self.function_start_token: str = "<function="
|
|
self.function_end_token: str = "</function>"
|
|
self.parameter_start_token: str = "<parameter="
|
|
self.parameter_end_token: str = "</parameter>"
|
|
|
|
def reset_streaming_state(self):
|
|
"""Reset streaming parsing state"""
|
|
|
|
self.deltas = []
|
|
# state for streaming
|
|
self.tool_call_index = 0
|
|
self.current_call_id = None
|
|
self.last_completed_call_id = None
|
|
self.current_function_name = None
|
|
self.current_function_open = False
|
|
self.parameters = {}
|
|
self.current_param_name = None
|
|
self.current_param_value = ""
|
|
self.current_param_value_converted = ""
|
|
self.current_param_is_first = False
|
|
self.should_emit_end_newline = False
|
|
self.start_quote_emitted = False
|
|
|
|
self.streaming_buffer = ""
|
|
self.last_processed_pos = 0
|
|
|
|
self.text_content_buffer = ""
|
|
|
|
# state for preprocessing and deferred parsing
|
|
self._pre_inside_parameter = False
|
|
self._pre_param_buffer = ""
|
|
self._pre_current_param_name = None
|
|
self.defer_current_parameter = False
|
|
self.deferred_param_raw_value = ""
|
|
|
|
# recreate parser
|
|
self.parser = ParserCreate()
|
|
self.setup_parser()
|
|
|
|
def parse_single_streaming_chunks(self, xml_chunk: str) -> DeltaMessage:
    """
    Parse single streaming XML chunk and return Delta response.

    This is the streaming entry point: chunks arrive one by one and the
    parser keeps its state between calls. The chunk is appended to
    ``self.streaming_buffer`` and every complete element is pushed through
    ``_process_complete_xml_elements``.

    Fallback: if the chunk contains ``</function>`` or ``</tool_call>`` but
    the expat callbacks did not produce the matching closing deltas for the
    current call, the end handlers are invoked manually.

    Args:
        xml_chunk: Single XML chunk string

    Returns:
        DeltaMessage: All deltas produced by this chunk, merged into one.
        When no element was processed:
        - buffered plain text (before the first tool call) is returned as
          ``content``;
        - otherwise a fallback-generated delta is returned, if any;
        - otherwise ``DeltaMessage(content=None)``.
    """
    # Record delta count before processing
    initial_delta_count = len(self.deltas)

    self.streaming_buffer += xml_chunk

    found_elements = self._process_complete_xml_elements()

    if found_elements:
        # If complete elements found, check if end events were missed
        # some tags may not have been triggered
        try:
            # Captured once, before any fallback below emits more deltas.
            new_deltas = self.deltas[initial_delta_count:]
            # If this chunk contains </function>
            # but didn't generate '}', then complete it
            if (
                self.current_call_id is not None
                and self.function_end_token in xml_chunk
            ):
                # - Added '}' (non-empty parameter ending)
                # - Added '{}' (empty parameter function)
                has_function_close = any(
                    (
                        td.tool_calls
                        and any(
                            (
                                tc.function
                                and tc.id == self.current_call_id
                                and isinstance(tc.function.arguments, str)
                                and (tc.function.arguments in ("}", "{}"))
                            )
                            for tc in td.tool_calls
                        )
                    )
                    for td in new_deltas
                )
                if not has_function_close:
                    # Close potentially unclosed element
                    if self.current_param_name:
                        self._end_element("parameter")
                    if self.current_function_name:
                        self._end_element("function")
            # If this chunk contains </tool_call>
            # but didn't generate final empty delta, then complete it
            if (
                self.current_call_id is not None
                and self.tool_call_end_token in xml_chunk
            ):
                # NOTE(review): the delta emitted by _start_element for the
                # function name also has type "function" and arguments "",
                # so it satisfies this check as well when it was produced in
                # the same chunk — confirm that is acceptable.
                has_toolcall_close = any(
                    (
                        td.tool_calls
                        and any(
                            (
                                tc.type == "function"
                                and tc.function
                                and tc.function.arguments == ""
                                and tc.id == self.current_call_id
                            )
                            for tc in td.tool_calls
                        )
                    )
                    for td in new_deltas
                )
                if not has_toolcall_close:
                    # Close potentially unclosed element
                    if self.current_param_name:
                        self._end_element("parameter")
                    if self.current_function_name:
                        self._end_element("function")
                    self._end_element("tool_call")
        except Exception as e:
            # Best-effort: a failing fallback is logged, not raised.
            logger.warning("Error with fallback parsing: %s", e)
        # Merge newly generated deltas into single response
        result_delta = self._merge_new_deltas_to_single_response(
            initial_delta_count
        )
        return result_delta
    else:
        # No complete elements, check if there's unoutput text content
        if self.text_content_buffer and self.tool_call_index == 0:
            # Has text content but no tool_call yet, output text content
            text_delta = DeltaMessage(content=self.text_content_buffer)
            self._emit_delta(text_delta)
            # Clear buffer to avoid duplicate output
            self.text_content_buffer = ""
            return text_delta

        # If this chunk contains end tags but wasn't triggered by parser,
        # manually complete end events.
        # Only runs while a tool call is open (current_call_id is set); it
        # does not verify that this is the same call that was open when
        # the chunk arrived.
        if self.current_call_id is not None and (
            self.function_end_token in xml_chunk
            or self.tool_call_end_token in xml_chunk
        ):
            # Close potentially unclosed element
            if self.current_param_name:
                self._end_element("parameter")
            if self.function_end_token in xml_chunk and self.current_function_name:
                self._end_element("function")
            if self.tool_call_end_token in xml_chunk:
                self._end_element("tool_call")
            # Return the merged delta result generated by this fallback
            result_delta = self._merge_new_deltas_to_single_response(
                initial_delta_count
            )
            return result_delta

        # No complete elements, return empty response
        return DeltaMessage(content=None)
|
|
|
|
def _escape_xml_special_chars(self, text: str) -> str:
|
|
"""
|
|
Escape XML special characters
|
|
Args:
|
|
text: Original text
|
|
Returns:
|
|
Escaped text
|
|
"""
|
|
xml_escapes = {
|
|
"&": "&",
|
|
"<": "<",
|
|
">": ">",
|
|
'"': """,
|
|
"'": "'",
|
|
}
|
|
|
|
for char, escape in xml_escapes.items():
|
|
text = text.replace(char, escape)
|
|
|
|
return text
|
|
|
|
def _process_complete_xml_elements(self) -> bool:
    """
    Process complete XML elements in buffer.

    Walks ``self.streaming_buffer`` from ``self.last_processed_pos`` and,
    for each complete unit found by ``_find_next_complete_element``:
    - consumes it without parsing if ``_should_skip_element`` says so
      (skipped units do not count as found);
    - otherwise preprocesses it and feeds it to the expat parser.
    If parsing raises, the error is logged and the unit is still consumed.

    Returns:
        bool: Whether complete elements were found and processed
    """
    found_any = False

    while self.last_processed_pos < len(self.streaming_buffer):
        # Find next complete xml element
        element, end_pos = self._find_next_complete_element(self.last_processed_pos)
        if element is None:
            # No complete element found, wait for more data
            break

        # Check if this element should be skipped
        if self._should_skip_element(element):
            self.last_processed_pos = end_pos
            continue

        # Found complete XML element, process it
        try:
            preprocessed_element = self._preprocess_xml_chunk(element)
            # Check if this is the first tool_call start
            if (
                (
                    preprocessed_element.strip().startswith("<tool_call>")
                    or preprocessed_element.strip().startswith("<function name=")
                )
                and self.tool_call_index == 0
            ) and self.text_content_buffer:
                # First tool_call starts,
                # output previously collected text content first
                text_delta = DeltaMessage(content=self.text_content_buffer)
                self._emit_delta(text_delta)
                # Clear buffer for potential subsequent text content
                self.text_content_buffer = ""

            # A new <tool_call> arrived while a previous call is still open
            # (current_call_id set): close the previous call by hand first.
            if (
                preprocessed_element.strip().startswith("<tool_call>")
                and self.tool_call_index > 0
                and self.current_call_id
            ):
                # Reset parser state but preserve generated deltas
                if self.current_param_name:
                    self._end_element("parameter")
                if self.current_function_open or self.current_function_name:
                    self._end_element("function")
                # Output final tool_call tail delta
                final_delta = DeltaMessage(
                    role=None,
                    content=None,
                    reasoning=None,
                    tool_calls=[
                        DeltaToolCall(
                            index=self.tool_call_index - 1,
                            id=self.current_call_id,
                            type="function",
                            function=DeltaFunctionCall(name=None, arguments=""),
                        )
                    ],
                )
                self._emit_delta(final_delta)
                # Reset XML parser and current call state
                # (helper defined elsewhere in the class).
                self._reset_xml_parser_after_tool_call()
            # Parse preprocessed element (may be "" while a deferred
            # parameter value is being buffered).
            self.parser.Parse(preprocessed_element, False)
            found_any = True

        except Exception as e:
            logger.warning("Error when parsing XML elements: %s", e)

        # Update processed position
        self.last_processed_pos = end_pos

    return found_any
|
|
|
|
def _should_skip_element(self, element: str) -> bool:
|
|
"""
|
|
Determine whether an element should be skipped
|
|
|
|
Args:
|
|
element: Element to evaluate
|
|
|
|
Returns:
|
|
bool: True means should skip, False means should process
|
|
"""
|
|
|
|
# If it's a tool_call XML tag, don't skip
|
|
if (
|
|
element.startswith(self.tool_call_start_token)
|
|
or element.startswith(self.function_start_token)
|
|
or element.startswith(self.parameter_start_token)
|
|
):
|
|
return False
|
|
|
|
# If currently not parsing tool calls and not blank,
|
|
# collect this text instead of skipping
|
|
# Only process other XML elements after tool_call appears,
|
|
# otherwise treat as plain text
|
|
if self.current_call_id is None and element:
|
|
# Collect text content to buffer
|
|
self.text_content_buffer += element
|
|
return True # Still skip, but content has been collected
|
|
|
|
# If currently parsing tool calls,
|
|
# this might be parameter value, don't skip
|
|
if self.current_call_id is not None:
|
|
return False
|
|
|
|
# Skip blank content
|
|
return not element
|
|
|
|
def _find_next_complete_element(self, start_pos: int) -> tuple[str | None, int]:
|
|
"""
|
|
Find next complete XML element from specified position
|
|
|
|
Args:
|
|
start_pos: Position to start searching
|
|
|
|
Returns:
|
|
(Complete element string, element end position),
|
|
returns (None, start_pos) if no complete element found
|
|
"""
|
|
buffer = self.streaming_buffer[start_pos:]
|
|
|
|
if not buffer:
|
|
return None, start_pos
|
|
|
|
if buffer.startswith("<"):
|
|
# Need to ensure no new < appears,
|
|
# find the nearest one between < and >
|
|
tag_end = buffer.find("<", 1)
|
|
tag_end2 = buffer.find(">", 1)
|
|
if tag_end != -1 and tag_end2 != -1:
|
|
# Next nearest is <
|
|
if tag_end < tag_end2:
|
|
return buffer[:tag_end], start_pos + tag_end
|
|
# Next nearest is >, means found XML element
|
|
else:
|
|
return buffer[: tag_end2 + 1], start_pos + tag_end2 + 1
|
|
elif tag_end != -1:
|
|
return buffer[:tag_end], start_pos + tag_end
|
|
elif tag_end2 != -1:
|
|
return buffer[: tag_end2 + 1], start_pos + tag_end2 + 1
|
|
else:
|
|
# If currently not parsing tool calls (entering a tool_call),
|
|
# check if starts with <tool_call> or <function=
|
|
if self.current_call_id is None:
|
|
# Check if might be start of <tool_call>
|
|
if buffer == "<tool_call>"[: len(buffer)]:
|
|
# Might be start of <tool_call>, wait for more data
|
|
return None, start_pos
|
|
elif (
|
|
buffer.startswith("<function=")
|
|
or buffer == "<function="[: len(buffer)]
|
|
):
|
|
# Might be start of <function=, wait for more data
|
|
# to get the complete function tag
|
|
return None, start_pos
|
|
else:
|
|
# Not start of <tool_call> or <function=, treat as text
|
|
return buffer, start_pos + len(buffer)
|
|
else:
|
|
# When parsing tool calls,
|
|
# wait for more data to get complete tag
|
|
return None, start_pos
|
|
else:
|
|
# Find text content (until next < or buffer end)
|
|
next_tag_pos = buffer.find("<")
|
|
if next_tag_pos != -1:
|
|
# Found text content
|
|
text_content = buffer[:next_tag_pos]
|
|
return text_content, start_pos + next_tag_pos
|
|
else:
|
|
# Buffer end is all text, process
|
|
# (no longer wait for more data)
|
|
remaining = buffer
|
|
return remaining, start_pos + len(remaining)
|
|
|
|
def _merge_new_deltas_to_single_response(self, initial_count: int) -> DeltaMessage:
    """
    Collapse the deltas appended since ``initial_count`` into one message.

    - Content strings are concatenated in order.
    - Tool-call deltas sharing an ``id`` are folded into the first entry
      with that id, provided that entry has a ``function``. A non-empty
      name replaces the stored one, argument strings are appended in
      order, and a truthy ``type`` overwrites the stored one.

    NOTE(review): folded entries are the same objects stored in
    ``self.deltas``, so merging mutates them in place.

    Args:
        initial_count: Delta count before processing

    Returns:
        ``DeltaMessage(content=None)`` when nothing new was emitted, the
        single new delta unchanged when there is exactly one, otherwise a
        freshly built merged ``DeltaMessage``.
    """
    if len(self.deltas) <= initial_count:
        return DeltaMessage(content=None)

    fresh = self.deltas[initial_count:]
    if len(fresh) == 1:
        return fresh[0]

    calls: list[DeltaToolCall] = []
    first_with_id: dict = {}
    text_parts: list[str] = []

    for delta in fresh:
        if delta.content:
            text_parts.append(delta.content)
        for incoming in delta.tool_calls or ():
            target = first_with_id.get(incoming.id)
            if not (target and target.function):
                # First sighting of this id (or nothing mergeable yet).
                calls.append(incoming)
                first_with_id.setdefault(incoming.id, incoming)
                continue
            fn = incoming.function
            if fn and fn.name:
                target.function.name = fn.name
            if fn and fn.arguments is not None:
                # Streaming JSON fragments are simply concatenated.
                target.function.arguments = (
                    target.function.arguments or ""
                ) + fn.arguments
            if incoming.type:
                target.type = incoming.type

    return DeltaMessage(
        content="".join(text_parts) or None,
        tool_calls=calls,
    )
|
|
|
|
def _preprocess_xml_chunk(self, chunk: str) -> str:
    """
    Preprocess XML chunk, handle non-standard formats,
    and escape special characters.

    - ``<function=NAME>`` / ``<parameter=NAME>`` are rewritten to
      ``<function name="NAME">`` / ``<parameter name="NAME">`` so expat
      can parse them.
    - After a ``<parameter ...>`` opener, the first content chunk decides
      whether the value is buffered ("deferred") until ``</parameter>``:
      - always for array/arr/sequence and dict*/list* types;
      - for ``object`` only when that chunk contains ``[``, ``{`` or ``(``
        and a single quote.
    - While a value is buffered, chunks return ``""``. The closing tag
      returns the escaped buffered text followed by ``</parameter>`` and
      sets ``defer_current_parameter`` for ``_end_element``.
    - Other chunks that do not start with a tool-call token are
      XML-escaped so expat treats them as character data.

    NOTE(review): the first chunk after a parameter opener is escaped as
    text when no deferral is needed, even if it is itself a tag (e.g. a
    second ``<parameter=...>`` or ``</function>`` with no ``</parameter>``
    in between).

    Args:
        chunk: Original XML chunk

    Returns:
        Processed XML chunk (``""`` while a deferred value is buffered)
    """

    # Check if this is a tool_call related element
    is_tool_call = False
    if chunk.startswith(self.tool_call_start_token) or chunk.startswith(
        self.tool_call_end_token
    ):
        is_tool_call = True
    if chunk.startswith(self.function_start_token) or chunk.startswith(
        self.function_end_token
    ):
        is_tool_call = True
    if chunk.startswith(self.parameter_start_token) or chunk.startswith(
        self.parameter_end_token
    ):
        is_tool_call = True
    # Handle <function=name> format -> <function name="name">
    processed = re.sub(r"<function=([^>]+)>", r'<function name="\1">', chunk)
    # Handle <parameter=name> format -> <parameter name="name">
    processed = re.sub(r"<parameter=([^>]+)>", r'<parameter name="\1">', processed)

    original_chunk = chunk
    # If in parameter value accumulation mode
    if self._pre_inside_parameter:
        # Parameter end: output accumulated raw text
        # safely then return </parameter>
        if processed.startswith("</parameter>"):
            body_text = self._pre_param_buffer
            # Trigger deferred parsing mode
            # literal_eval+json output in end_element
            self.defer_current_parameter = True
            self.deferred_param_raw_value = body_text
            # Clean up state
            self._pre_inside_parameter = False
            self._pre_param_buffer = ""
            self._pre_current_param_name = None
            safe_text = self._escape_xml_special_chars(body_text)
            return f"{safe_text}</parameter>"
        else:
            # If this is the first block of content after entering parameter
            # evaluate if deferred parsing is needed;
            # If not needed, exit accumulation mode
            # and pass through directly
            if self._pre_param_buffer == "":
                # Get current parameter type
                param_type = (
                    self._get_param_type(self._pre_current_param_name)
                    if self._pre_current_param_name
                    else "string"
                )
                # Only these types need deferred parsing to
                # handle Python literals containing single quotes
                is_object_type = param_type in ["object"]
                is_complex_type = (
                    param_type in ["array", "arr", "sequence"]
                    or param_type.startswith("dict")
                    or param_type.startswith("list")
                )

                # Container characters in this (first) chunk hint that the
                # object value may be a Python literal
                has_container_hint = (
                    ("[" in original_chunk)
                    or ("{" in original_chunk)
                    or ("(" in original_chunk)
                )

                # Determine if deferred parsing is needed
                need_defer = False
                if is_complex_type:
                    # Complex type, always need deferred parsing
                    need_defer = True
                elif (
                    is_object_type
                    and has_container_hint
                    and ("'" in original_chunk)
                ):
                    # Object type with container symbols
                    # and single quotes, need deferred parsing
                    need_defer = True

                if not need_defer:
                    # No need for deferred parsing,
                    # exit parameter mode directly
                    self._pre_inside_parameter = False
                    return self._escape_xml_special_chars(original_chunk)
            # Deferred: buffer the raw text and hold it back from expat
            self._pre_param_buffer += original_chunk
            return ""

    # Parameter start: enable accumulation
    if processed.startswith("<parameter name="):
        m = re.match(r'<parameter name="([^"]+)">', processed)
        if m:
            self._pre_current_param_name = m.group(1)
        self._pre_inside_parameter = True
        self._pre_param_buffer = ""
        return processed

    # If processed doesn't contain special_token, escape processed
    # This is because XML parsing encounters special characters
    # and reports errors, so escaping is needed
    if not is_tool_call:
        processed = self._escape_xml_special_chars(processed)
    return processed
|
|
|
|
def _emit_delta(self, delta: DeltaMessage):
|
|
"""Emit Delta response (streaming output)"""
|
|
self.deltas.append(delta)
|
|
|
|
def _auto_close_open_parameter_if_needed(self, incoming_tag: str | None = None):
|
|
"""Before starting to process new elements,
|
|
if there are unclosed tags from before,
|
|
automatically complete their endings to the parser.
|
|
- If there are unclosed parameters,
|
|
it's equivalent to feeding `</parameter>`
|
|
- When about to start a new function or tool_call,
|
|
if there are unclosed functions, complete `</function>`.
|
|
- When about to start a new tool_call,
|
|
if there are unclosed tool_calls, complete `</tool_call>`.
|
|
"""
|
|
# First close unclosed parameters
|
|
if self.current_param_name:
|
|
self._end_element("parameter")
|
|
|
|
# If about to start new function or tool_call,
|
|
# and there are unclosed functions, close function first
|
|
if incoming_tag in ("function", "tool_call") and self.current_function_name:
|
|
self._end_element("function")
|
|
|
|
# If about to start new tool_call,
|
|
# and there are unclosed tool_calls, close tool_call first
|
|
if incoming_tag == "tool_call" and self.current_call_id:
|
|
self._end_element("tool_call")
|
|
|
|
def _start_element(self, name: str, attrs: dict[str, str]):
    """Handle XML start element events.

    - ``tool_call``: synthesizes end events for anything still open, then
      starts a new call (fresh ``parameters``, a new call id, incremented
      ``tool_call_index``).
    - ``function``: opens a tool call first if none is active, closes an
      open parameter/function, records the function name and emits a delta
      carrying that name.
    - ``parameter``: closes an open parameter, then emits the JSON key
      (``{"name": `` when no parameter has been stored yet for this call,
      ``, "name": `` afterwards). Quotes around the value are decided
      later, once the parameter type is known.

    Args:
        name: Element name reported by expat.
        attrs: Element attributes; ``name`` carries the function/parameter
            name for tags rewritten to ``<x name="...">``.
    """

    if name == "root":
        return

    if name == "tool_call":
        # Before opening new tool_call,
        # automatically complete previous unclosed tags
        self._auto_close_open_parameter_if_needed("tool_call")

        self.parameters = {}
        self.current_call_id = make_tool_call_id()
        self.current_param_is_first = True
        self.tool_call_index += 1
    elif name.startswith("function"):
        # If missing tool_call, manually complete
        if not self.current_call_id:
            self._start_element("tool_call", {})
        # Before opening new function,
        # automatically complete previous unclosed tags (parameter/function)
        self._auto_close_open_parameter_if_needed("function")
        function_name = self._extract_function_name(name, attrs)
        self.current_function_name = function_name
        self.current_function_open = True
        if function_name:
            delta = DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=self.tool_call_index - 1,
                        id=self.current_call_id,
                        type="function",
                        function=DeltaFunctionCall(
                            name=function_name, arguments=""
                        ),
                    )
                ]
            )
            self._emit_delta(delta)
    elif name.startswith("parameter"):
        # If previous parameter hasn't ended normally,
        # complete its end first, then start new parameter
        self._auto_close_open_parameter_if_needed("parameter")
        param_name = self._extract_parameter_name(name, attrs)
        self.current_param_name = param_name
        self.current_param_value = ""
        self.current_param_value_converted = ""
        self.start_quote_emitted = False  # Reset start quote flag

        # Only output the key and colon here; the value's quotes (if any)
        # are emitted once the parameter type is known.
        if param_name:
            # JSON-encode the key so names containing quotes, backslashes
            # or control characters still produce valid JSON.
            json_key = json.dumps(param_name, ensure_ascii=False)
            is_first = not self.parameters
            if is_first:
                # First parameter: open the arguments object.
                arguments = "{" + json_key + ": "
            else:
                # Subsequent parameters: separate with a comma.
                arguments = ", " + json_key + ": "
            delta = DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=self.tool_call_index - 1,
                        id=self.current_call_id,
                        type="function",
                        function=DeltaFunctionCall(name=None, arguments=arguments),
                    )
                ]
            )
            self._emit_delta(delta)
            self.current_param_is_first = is_first
|
|
|
|
def _char_data(self, data: str):
    """Handle XML character data events.

    Only acts while a parameter is open (``self.current_param_name`` set).

    A single trailing newline of each piece of data is held back in
    ``should_emit_end_newline`` and prepended to the next piece. If the
    parameter ends first, ``_end_element`` clears the flag, so a newline
    directly before ``</parameter>`` is not emitted.

    - Deferred parameters (``defer_current_parameter``) only accumulate
      text into ``current_param_value``; ``_end_element`` emits them.
    - Otherwise the accumulated value is converted by its declared type
      and only the part of the JSON text beyond what was already emitted
      (``current_param_value_converted``) is sent.
    - String-like types get an opening ``"`` emitted once, before any
      value text.
    """
    if data and self.current_param_name:
        # If preprocessing stage determines deferred parsing is needed,
        # only cache character data, no streaming output
        if self.defer_current_parameter:
            original_data = data
            if self.should_emit_end_newline:
                original_data = "\n" + original_data
                self.should_emit_end_newline = False
            if original_data.endswith("\n"):
                self.should_emit_end_newline = True
                original_data = original_data[:-1]
            self.current_param_value += original_data
            return

        param_type = self._get_param_type(self.current_param_name)

        # Check if this is the first time receiving data for this parameter
        # If this is the first packet of data and starts with \n, remove \n
        # (this applies to every packet while current_param_value is still
        # empty, so consecutive packets can each lose one leading newline)
        if not self.current_param_value and data.startswith("\n"):
            data = data[1:]

        # Output start quote for string type (if not already output)
        if (
            param_type in ["string", "str", "text", "varchar", "char", "enum"]
            and not self.start_quote_emitted
        ):
            quote_delta = DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=self.tool_call_index - 1,
                        id=self.current_call_id,
                        type="function",
                        function=DeltaFunctionCall(name=None, arguments='"'),
                    )
                ]
            )
            self._emit_delta(quote_delta)
            self.start_quote_emitted = True

        if not data:
            return

        original_data = data
        # Delay output of trailing newline
        if self.should_emit_end_newline:
            original_data = "\n" + original_data
            self.should_emit_end_newline = False
        if original_data.endswith("\n"):
            self.should_emit_end_newline = True
            original_data = original_data[:-1]
        self.current_param_value += original_data

        # convert parameter value by param_type
        converted_value = self._convert_param_value(
            self.current_param_value, param_type
        )
        output_data = self._convert_for_json_streaming(converted_value, param_type)

        # Emit only the new suffix. NOTE(review): assumes each conversion
        # output extends the previous one; if _convert_for_json_streaming
        # rewrites earlier characters the emitted stream would diverge.
        delta_data = output_data[len(self.current_param_value_converted) :]
        self.current_param_value_converted = output_data

        delta = DeltaMessage(
            tool_calls=[
                DeltaToolCall(
                    index=self.tool_call_index - 1,
                    id=self.current_call_id,
                    type="function",
                    function=DeltaFunctionCall(name=None, arguments=delta_data),
                )
            ]
        )
        self._emit_delta(delta)
|
|
|
|
def _end_element(self, name: str):
    """Handle XML end element events.

    - ``parameter``: emits the closing JSON text for the open value.
      - Deferred values are parsed whole with ``ast.literal_eval`` and
        emitted once as JSON, falling back to a JSON string of the raw
        text.
      - String-like types get their closing quote (or ``""`` if no opening
        quote went out).
      - Any other type whose value produced no output emits ``null``, so
        the arguments remain valid JSON.
    - ``function``: emits ``}`` (or ``{}`` when no parameter was stored).
    - ``tool_call``: closes a still-open function, then emits an
      empty-arguments tail delta. Non-blank ``text_content_buffer`` is
      emitted as ``content`` (the buffer is not cleared here). Finally
      ``_reset_xml_parser_after_tool_call`` is called.

    Args:
        name: Element name reported by expat, or passed by the manual
            fallbacks that synthesize missing end tags.
    """

    def _emit_arguments(arguments: str) -> None:
        # Every delta emitted here targets the currently open tool call;
        # index and id are read at call time.
        self._emit_delta(
            DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=self.tool_call_index - 1,
                        id=self.current_call_id,
                        type="function",
                        function=DeltaFunctionCall(name=None, arguments=arguments),
                    )
                ]
            )
        )

    if name == "root":
        return

    is_function = name.startswith("function")
    is_parameter = name.startswith("parameter")

    # If function or tool_call ends and there are still unclosed parameters,
    # complete parameter end first
    if (is_function or name == "tool_call") and self.current_param_name:
        self._auto_close_open_parameter_if_needed()

    if is_parameter and self.current_param_name:
        # End current parameter
        param_name = self.current_param_name
        param_value = self.current_param_value

        # If in deferred parsing mode,
        # perform overall parsing on raw content
        # accumulated in preprocessing stage and output once
        if self.defer_current_parameter:
            raw_text = (
                self.deferred_param_raw_value
                if self.deferred_param_raw_value
                else param_value
            )
            try:
                # Re-attach a trailing newline that _char_data held back
                if self.should_emit_end_newline:
                    raw_for_parse = raw_text + "\n"
                else:
                    raw_for_parse = raw_text
                parsed_value = ast.literal_eval(raw_for_parse)
                output_arguments = json.dumps(parsed_value, ensure_ascii=False)
            except Exception:
                # Not a literal (or not JSON-serializable):
                # output the raw text as a JSON string instead
                output_arguments = json.dumps(raw_text, ensure_ascii=False)
                parsed_value = raw_text

            _emit_arguments(output_arguments)

            # Clean up and store
            self.should_emit_end_newline = False
            self.parameters[param_name] = parsed_value
            self.current_param_name = None
            self.current_param_value = ""
            self.current_param_value_converted = ""
            self.start_quote_emitted = False
            self.defer_current_parameter = False
            self.deferred_param_raw_value = ""
            return

        param_type = self._get_param_type(param_name)

        # convert complete parameter value by param_type
        converted_value = self._convert_param_value(param_value, param_type)

        if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
            if not param_value and not self.start_quote_emitted:
                # No opening quote went out: emit a complete empty string
                _emit_arguments('""')
            else:
                # Close the string that _char_data opened
                _emit_arguments('"')
        elif not self.current_param_value_converted:
            # Nothing was streamed for this non-string value (e.g. an empty
            # <parameter=n></parameter>); without this the arguments would
            # end up as '{"n": }', which is not valid JSON.
            _emit_arguments("null")

        self.should_emit_end_newline = False
        # Store converted value
        self.parameters[param_name] = converted_value
        self.current_param_name = None
        self.current_param_value = ""
        self.current_param_value_converted = ""
        self.start_quote_emitted = False

    elif is_function:
        # Close the arguments object, or emit an empty one when no
        # parameter was stored for this call
        _emit_arguments("}" if self.parameters else "{}")
        self.current_function_open = False

    elif name == "tool_call":
        # Before ending tool_call,
        # ensure function is closed to complete missing right brace
        if self.current_function_open:
            # If there are still unclosed parameters, close them first
            if self.current_param_name:
                self._end_element("parameter")
            # Close function, ensure output '}' or '{}'
            self._end_element("function")
        # Final Delta: empty arguments mark the end of this call
        _emit_arguments("")

        # Check if there's text content to output (between tool_calls)
        if self.text_content_buffer.strip():
            text_delta = DeltaMessage(content=self.text_content_buffer)
            self._emit_delta(text_delta)

        self._reset_xml_parser_after_tool_call()
|
|
|
|
def setup_parser(self):
|
|
"""Set up XML parser event handlers"""
|
|
self.parser.buffer_text = True
|
|
self.parser.StartElementHandler = self._start_element
|
|
self.parser.EndElementHandler = self._end_element
|
|
self.parser.CharacterDataHandler = self._char_data
|
|
|
|
def set_tools(self, tools: list[ChatCompletionToolsParam] | None):
|
|
"""Set tool configuration information"""
|
|
self.tools = tools
|
|
|
|
def _extract_function_name(self, name: str, attrs: dict[str, str]) -> str | None:
|
|
"""Extract function name from various formats"""
|
|
if attrs and "name" in attrs:
|
|
return attrs["name"]
|
|
|
|
if "=" in name:
|
|
parts = name.split("=", 1)
|
|
if len(parts) == 2 and parts[0] == "function":
|
|
return parts[1]
|
|
|
|
return None
|
|
|
|
def _extract_parameter_name(self, name: str, attrs: dict[str, str]) -> str | None:
|
|
"""Extract parameter name from various formats"""
|
|
if attrs and "name" in attrs:
|
|
return attrs["name"]
|
|
|
|
if "=" in name:
|
|
parts = name.split("=", 1)
|
|
if len(parts) == 2 and parts[0] == "parameter":
|
|
return parts[1]
|
|
|
|
return None
|
|
|
|
def _get_param_type(self, param_name: str) -> str:
|
|
"""Get parameter type based on tool configuration, defaults to string
|
|
Args:
|
|
param_name: Parameter name
|
|
|
|
Returns:
|
|
Parameter type
|
|
"""
|
|
if not self.tools or not self.current_function_name:
|
|
return "string"
|
|
|
|
for tool in self.tools:
|
|
if not hasattr(tool, "type") or not (
|
|
hasattr(tool, "function") and hasattr(tool.function, "name")
|
|
):
|
|
continue
|
|
if (
|
|
tool.type == "function"
|
|
and tool.function.name == self.current_function_name
|
|
):
|
|
if not hasattr(tool.function, "parameters"):
|
|
return "string"
|
|
params = tool.function.parameters
|
|
if isinstance(params, dict) and "properties" in params:
|
|
properties = params["properties"]
|
|
if param_name in properties and isinstance(
|
|
properties[param_name], dict
|
|
):
|
|
return self.repair_param_type(
|
|
str(properties[param_name].get("type", "string"))
|
|
)
|
|
elif isinstance(params, dict) and param_name in params:
|
|
param_config = params[param_name]
|
|
if isinstance(param_config, dict):
|
|
return self.repair_param_type(
|
|
str(param_config.get("type", "string"))
|
|
)
|
|
break
|
|
return "string"
|
|
|
|
def repair_param_type(self, param_type: str) -> str:
    """Normalise a declared type name, mapping unknown names to ``"string"``.

    Recognised names are returned unchanged. These are string-like names,
    boolean-like names, container names, and anything starting with an
    integer, number or container prefix. Matching is case-sensitive.

    Args:
        param_type: Parameter type

    Returns:
        ``param_type`` itself if recognised, otherwise ``"string"``.
    """
    known_names = (
        "string", "str", "text", "varchar", "char", "enum",
        "boolean", "bool", "binary",
        "object", "array", "arr", "sequence",
    )
    known_prefixes = (
        "int", "uint", "long", "short", "unsigned",
        "num", "float",
        "dict", "list",
    )
    is_known = param_type in known_names or param_type.startswith(known_prefixes)
    return param_type if is_known else "string"
|
|
|
|
def _convert_param_value(self, param_value: str, param_type: str) -> Any:
    """Convert a raw parameter string into a Python value for its type.

    Conversion rules:

    - The literal ``null`` (any case) becomes ``None`` regardless of type.
    - Integer-like types are converted with ``int()``.
    - Number/float types are converted with ``float()``; integral results
      are returned as ``int``.
    - Boolean types are ``True`` only for ``"true"`` (any case).
    - Every other type is returned unchanged.

    If a numeric conversion fails, including non-finite values such as
    ``inf`` or ``nan``, a warning is logged and the original string is
    returned.

    Args:
        param_value: Parameter value
        param_type: Parameter type; compared after stripping whitespace
            and lower-casing.

    Returns:
        Converted value
    """
    if param_value.lower() == "null":
        return None

    # Only used so the warnings identify which parameter/tool was bad.
    param_name = getattr(self, "current_param_name", None)
    function_name = getattr(self, "current_function_name", None)

    param_type = param_type.strip().lower()
    if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
        return param_value

    if param_type.startswith(("int", "uint", "long", "short", "unsigned")):
        try:
            return int(param_value)
        except (ValueError, TypeError):
            logger.warning(
                "Parsed value '%s' of parameter '%s' is not an integer "
                "in tool '%s', degenerating to string.",
                param_value,
                param_name,
                function_name,
            )
            return param_value

    if param_type.startswith(("num", "float")):
        try:
            number = float(param_value)
            # int() raises OverflowError for +/-inf and ValueError for nan,
            # so non-finite values also degrade to the raw string.
            integral = int(number)
        except (ValueError, TypeError, OverflowError):
            logger.warning(
                "Parsed value '%s' of parameter '%s' is not a float "
                "in tool '%s', degenerating to string.",
                param_value,
                param_name,
                function_name,
            )
            return param_value
        return integral if number == integral else number

    if param_type in ["boolean", "bool", "binary"]:
        return param_value.lower() == "true"

    return param_value
|
|
|
|
def _convert_for_json_streaming(self, converted_value: Any, param_type: str) -> str:
    """Render a converted parameter value as text for streamed JSON output.

    Returns are as follows:

    - ``None`` and ``""`` render as an empty string. ``0`` and ``False``
      are not treated as empty.
    - For string types, the value is JSON-escaped without its surrounding
      quotes.
    - For other types, strings are returned verbatim and non-strings are
      JSON-encoded.

    Args:
        converted_value: Converted value
        param_type: Parameter type (compared exactly, no normalisation).

    Returns:
        Converted string for streaming output
    """
    if converted_value is None or converted_value == "":
        return ""

    string_types = ("string", "str", "text", "varchar", "char", "enum")
    if param_type in string_types:
        quoted = json.dumps(converted_value, ensure_ascii=False)
        # Drop the enclosing double quotes, keep the escaped body.
        return quoted[1:-1]

    if isinstance(converted_value, str):
        return converted_value
    return json.dumps(converted_value, ensure_ascii=False)
|
|
|
|
def _reset_xml_parser_after_tool_call(self):
    """Prepare for the next ``<tool_call>`` once the current one has ended.

    Every tool call is handled as its own XML document. A brand-new expat
    parser is therefore created and configured through ``setup_parser()``,
    and all per-call parsing state is cleared. The id of the call that just
    ended, if any, is preserved in ``last_completed_call_id``.
    """
    # Fresh expat parser for the next document, configured before use.
    self.parser = ParserCreate()
    self.setup_parser()

    finished_call_id = self.current_call_id
    if finished_call_id:
        self.last_completed_call_id = finished_call_id

    # Which call/function is currently open.
    self.current_call_id = None
    self.current_function_name = None
    self.current_function_open = False

    # Parameter accumulation for the current call.
    self.parameters = {}
    self.current_param_name = None
    self.current_param_value = ""
    self.current_param_value_converted = ""
    self.current_param_is_first = False

    # Remaining per-call flags and text buffer.
    self.should_emit_end_newline = False
    self.start_quote_emitted = False
    self.text_content_buffer = ""

    # Preprocessing and deferred-parameter state.
    self._pre_inside_parameter = False
    self._pre_param_buffer = ""
    self._pre_current_param_name = None
    self.defer_current_parameter = False
    self.deferred_param_raw_value = ""
|
|
|
|
|
|
class Qwen3XMLToolParser(ToolParser):
    """Tool parser that delegates XML tool-call parsing to
    ``StreamingXMLToolCallParser``.

    Besides returning parsed tool calls, it maintains the two tracking
    arrays that ``serving_chat.py`` relies on:

    - ``prev_tool_call_arr``: one ``{"name", "arguments"}`` dict per call
      index.
    - ``streamed_args_for_tool``: the argument text per call index.
    """

    def __init__(self, tokenizer: TokenizerLike):
        super().__init__(tokenizer)
        self.parser = StreamingXMLToolCallParser()

        # Add missing attributes for compatibility with serving_chat.py
        self.prev_tool_call_arr: list[dict] = []
        self.streamed_args_for_tool: list[str] = []

        logger.info(
            "vLLM Successfully import tool parser %s !", self.__class__.__name__
        )

    def _ensure_tracking_slot(self, index: int | None) -> int:
        """Resolve a tool-call index and grow both tracking arrays to hold it.

        Args:
            index: Index reported on the tool call. ``None`` means the most
                recent slot, or slot 0 when nothing is tracked yet. The old
                ``len - 1`` fallback yielded -1 on an empty array and then
                raised ``IndexError``.

        Returns:
            A valid index into both tracking arrays.
        """
        if index is None:
            index = max(len(self.prev_tool_call_arr) - 1, 0)
        while len(self.prev_tool_call_arr) <= index:
            self.prev_tool_call_arr.append({"name": "", "arguments": ""})
        while len(self.streamed_args_for_tool) <= index:
            self.streamed_args_for_tool.append("")
        return index

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Parse a complete model output into tool calls.

        Parser state and tracking arrays are reset first. The whole output
        is then parsed in one pass. Only calls with a function name are
        returned.

        Args:
            model_output: Full generated text.
            request: The originating request; its ``tools`` supply
                parameter types.

        Returns:
            Extracted tool calls plus the content reported by the parser.
        """
        self.parser.reset_streaming_state()
        # Reset tool call tracking arrays for new extraction
        self.prev_tool_call_arr = []
        self.streamed_args_for_tool = []
        if request:
            self.parser.set_tools(request.tools)

        result = self.parser.parse_single_streaming_chunks(model_output)
        if not result.tool_calls:
            return ExtractedToolCallInformation(
                tool_calls=[],
                tools_called=False,
                content=result.content,
            )

        tool_calls = []
        for tool_call in result.tool_calls:
            function = tool_call.function
            if not (function and function.name):
                continue
            tool_calls.append(
                ToolCall(
                    id=tool_call.id,
                    type=tool_call.type,
                    function=FunctionCall(
                        name=function.name,
                        arguments=function.arguments,
                    ),
                )
            )

            # Update tool call tracking arrays for compatibility
            slot = self._ensure_tracking_slot(tool_call.index)
            self.prev_tool_call_arr[slot]["name"] = function.name
            self.prev_tool_call_arr[slot]["arguments"] = function.arguments
            if function.arguments:
                self.streamed_args_for_tool[slot] = function.arguments

        return ExtractedToolCallInformation(
            tool_calls=tool_calls,
            tools_called=len(tool_calls) > 0,
            content=result.content,
        )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Parse one streamed delta and update the tracking arrays.

        State is reset at the start of a stream (empty ``previous_text``).
        Argument text from each delta is appended to the tracked
        arguments.

        Returns:
            The parser's delta message, an empty-content message or
            ``None`` for token-only (empty-text) deltas, as described in
            the inline comment below.
        """
        if not previous_text:
            self.parser.reset_streaming_state()
            # Reset tool call tracking arrays for new streaming session
            self.prev_tool_call_arr = []
            self.streamed_args_for_tool = []
            if request:
                self.parser.set_tools(request.tools)

        # Model sometimes outputs separately causing delta_text to be empty.
        # If there were tool_calls before and all current tool_calls have ended,
        # return an empty tool_call for outer streaming output
        # to correctly output tool_call field
        if not delta_text and delta_token_ids:
            open_calls = current_text.count(
                self.parser.tool_call_start_token
            ) - current_text.count(self.parser.tool_call_end_token)
            if (open_calls == 0 and self.parser.tool_call_index > 0) or (
                not self.parser.tool_call_index and current_text
            ):
                return DeltaMessage(content="")
            return None

        # Parse the delta text and get the result
        result = self.parser.parse_single_streaming_chunks(delta_text)

        # Update tool call tracking arrays based on incremental parsing results
        if result and result.tool_calls:
            for tool_call in result.tool_calls:
                function = tool_call.function
                if not function:
                    continue
                slot = self._ensure_tracking_slot(tool_call.index)

                if function.name:
                    self.prev_tool_call_arr[slot]["name"] = function.name

                if function.arguments is not None:
                    # Arguments arrive incrementally; append this fragment.
                    self.prev_tool_call_arr[slot]["arguments"] += function.arguments
                    self.streamed_args_for_tool[slot] += function.arguments

        return result
|