[Bugfix] Add support for <tool_call> format in streaming mode for XLAM Tool Parser (#22769)

Signed-off-by: Devon Peroutky <devon@kindo.ai>
This commit is contained in:
Code Jesus 2025-08-31 23:07:54 -07:00 committed by GitHub
parent 1cb39dbcdd
commit 422e793fa6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 296 additions and 24 deletions

View File

@ -2,12 +2,17 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Generator
from typing import Optional
import pytest
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage, FunctionCall,
ToolCall)
from vllm.entrypoints.openai.tool_parsers import xLAMToolParser
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.transformers_utils.detokenizer import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
# Use a common model that is likely to be available
MODEL = "Salesforce/Llama-xLAM-2-8B-fc-r"
@ -36,6 +41,56 @@ def assert_tool_calls(actual_tool_calls: list[ToolCall],
assert actual_tool_call.function == expected_tool_call.function
def stream_delta_message_generator(
    xlam_tool_parser: xLAMToolParser,
    xlam_tokenizer: AnyTokenizer,
    model_output: str,
    request: Optional[ChatCompletionRequest] = None,
) -> Generator[DeltaMessage, None, None]:
    """Drive the xLAM parser exactly as the streaming server would.

    The full ``model_output`` is tokenized once, then replayed one token
    at a time through incremental detokenization; each resulting text
    delta is handed to ``extract_tool_calls_streaming`` and every
    non-empty ``DeltaMessage`` the parser emits is yielded.
    """
    token_ids = xlam_tokenizer.encode(model_output,
                                      add_special_tokens=False)

    # Incremental-detokenization state carried between iterations.
    text_so_far = ""
    tokens_so_far = None
    prefix_off = 0
    read_off = 0

    for pos, token_id in enumerate(token_ids):
        seen_ids = token_ids[:pos]
        seen_plus_new_ids = token_ids[:pos + 1]

        (tokens, delta_text, prefix_next,
         read_next) = detokenize_incrementally(
             tokenizer=xlam_tokenizer,
             all_input_ids=seen_plus_new_ids,
             prev_tokens=tokens_so_far,
             prefix_offset=prefix_off,
             read_offset=read_off,
             skip_special_tokens=False,
             spaces_between_special_tokens=True,
         )
        updated_text = text_so_far + delta_text

        message = xlam_tool_parser.extract_tool_calls_streaming(
            text_so_far,
            updated_text,
            delta_text,
            seen_ids,
            seen_plus_new_ids,
            [token_id],
            request=request,
        )
        if message:
            yield message

        # Roll the streaming state forward for the next token.
        text_so_far = updated_text
        tokens_so_far = (tokens_so_far +
                         tokens if tokens_so_far else tokens)
        prefix_off = prefix_next
        read_off = read_next
def test_extract_tool_calls_no_tools(xlam_tool_parser):
model_output = "This is a test"
extracted_tool_calls = xlam_tool_parser.extract_tool_calls(
@ -51,6 +106,7 @@ def test_extract_tool_calls_no_tools(xlam_tool_parser):
"single_tool_with_think_tag",
"single_tool_with_json_code_block",
"single_tool_with_tool_calls_tag",
"single_tool_with_tool_call_xml_tags",
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
@ -118,6 +174,20 @@ def test_extract_tool_calls_no_tools(xlam_tool_parser):
],
"I'll check the weather for you.",
),
(
"""I'll help you check the weather.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""", # noqa: E501
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}),
))
],
"I'll help you check the weather.",
),
],
)
def test_extract_tool_calls(xlam_tool_parser, model_output,
@ -245,3 +315,147 @@ def test_streaming_with_list_structure(xlam_tool_parser):
assert hasattr(result, "tool_calls")
assert len(result.tool_calls) == 1
assert result.tool_calls[0].function.name == "get_current_weather"
@pytest.mark.parametrize(
    ids=[
        "parallel_tool_calls",
        "single_tool_with_think_tag",
        "single_tool_with_json_code_block",
        "single_tool_with_tool_calls_tag",
        "single_tool_with_tool_call_xml_tags",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            """[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}, {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}]""",  # noqa: E501
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Dallas",
                        "state": "TX",
                        "unit": "fahrenheit",
                    }),
                )),
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Orlando",
                        "state": "FL",
                        "unit": "fahrenheit",
                    }),
                )),
            ],
            "",
        ),
        (
            """<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""",  # noqa: E501
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Dallas",
                        "state": "TX",
                        "unit": "fahrenheit",
                    }),
                ))
            ],
            "<think>I'll help you with that.</think>",
        ),
        (
            """```json\n[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]\n```""",  # noqa: E501
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Dallas",
                        "state": "TX",
                        "unit": "fahrenheit",
                    }),
                ))
            ],
            "",
        ),
        (
            """[TOOL_CALLS][{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""",  # noqa: E501
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Dallas",
                        "state": "TX",
                        "unit": "fahrenheit",
                    }),
                ))
            ],
            "",
        ),
        (
            """I can help with that.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""",  # noqa: E501
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Dallas",
                        "state": "TX",
                        "unit": "fahrenheit",
                    }),
                ))
            ],
            "I can help with that.",
        ),
    ],
)
def test_extract_tool_calls_streaming_incremental(
    xlam_tool_parser,
    xlam_tokenizer,
    model_output,
    expected_tool_calls,
    expected_content,
):
    """Verify the XLAM Parser streaming behavior by verifying each chunk is as expected."""  # noqa: E501
    request = ChatCompletionRequest(model=MODEL, messages=[], tools=[])

    # Replay the model output token by token and collect every delta.
    deltas = list(
        stream_delta_message_generator(xlam_tool_parser, xlam_tokenizer,
                                       model_output, request))

    # Streaming should produce several chunks, not a single blob.
    assert len(deltas) >= 3

    first_expected = expected_tool_calls[0]

    # The first chunk carrying a tool-call id is the "header" chunk: it
    # announces id/name/type before any argument bytes are streamed.
    header = next(
        (d for d in deltas if d.tool_calls and d.tool_calls[0].id), None)
    assert header is not None
    header_call = header.tool_calls[0]
    assert header_call.function.name == first_expected.function.name
    assert header_call.type == "function"
    # Arguments on the header chunk are either absent or an empty string.
    if header_call.function.arguments is not None:
        assert header_call.function.arguments == ""

    # Collect the non-empty argument fragments belonging to the first
    # tool call only (index == 0).
    arg_pieces = [
        d.tool_calls[0].function.arguments for d in deltas
        if d.tool_calls and d.tool_calls[0].function.arguments
        and d.tool_calls[0].function.arguments != ""
        and d.tool_calls[0].index == 0
    ]

    # Arguments must arrive incrementally, i.e. in more than one piece...
    assert len(arg_pieces) > 1
    # ...and concatenate to exactly the expected JSON arguments.
    # NOTE(review): expected_content is intentionally not asserted here;
    # content deltas interleave with tool-call deltas during streaming.
    assert json.loads("".join(arg_pieces)) == json.loads(
        first_expected.function.arguments)

View File

@ -186,11 +186,31 @@ class xLAMToolParser(ToolParser):
"""
Extract tool calls for streaming mode.
"""
# Simplify detection: if it begins with "[" treat it as a function call
is_function_call = (current_text.strip().startswith("["))
# First, check for a definitive start of a tool call block.
# This prevents premature parsing of incomplete output.
stripped_text = current_text.strip()
preprocessed_content, preprocessed_tool_calls = (
self.preprocess_model_output(current_text))
# If not a function call, return normal content
if not is_function_call:
# For JSON code blocks, we need to detect them earlier, even if incomplete
has_potential_json_block = ("```json" in current_text
or "```\n[" in current_text
or "[TOOL_CALLS]" in current_text
or "<tool_call>" in current_text)
is_tool_call_block = (
stripped_text.startswith("[")
or stripped_text.startswith("<tool_call>")
or stripped_text.startswith("[TOOL_CALLS]") or
# Check if we have thinking tags with JSON-like content following
("</think>[" in current_text) or
# Check if the text contains a JSON array after preprocessing
preprocessed_tool_calls is not None or
# For JSON code blocks, detect early if we see enough structure
(has_potential_json_block and '"name"' in current_text
and '"arguments"' in current_text))
if not is_tool_call_block:
return DeltaMessage(content=delta_text)
try:
@ -204,7 +224,10 @@ class xLAMToolParser(ToolParser):
# Try parsing as JSON to check for complete tool calls
try:
parsed_tools = json.loads(current_text)
# Use preprocessed tool calls if available
tool_calls_text = (preprocessed_tool_calls if
preprocessed_tool_calls else current_text)
parsed_tools = json.loads(tool_calls_text)
if isinstance(parsed_tools, list):
# Update our tool array for next time
self.prev_tool_call_arr = parsed_tools
@ -257,13 +280,40 @@ class xLAMToolParser(ToolParser):
return delta
# Use regex to identify tool calls in the output
# Use preprocessed tool calls text for better parsing, but also try to extract from incomplete JSON blocks
search_text = (preprocessed_tool_calls
if preprocessed_tool_calls else current_text)
# For JSON code blocks that aren't complete yet, try to extract the JSON content
if not preprocessed_tool_calls and has_potential_json_block:
# Try to extract the JSON array from within the code block
json_match = re.search(r"```(?:json)?\s*([\s\S]*?)(?:```|$)",
current_text)
if json_match:
potential_json = json_match.group(1).strip()
# Use this as search text even if it's incomplete
if potential_json.startswith("[") and (
'"name"' in potential_json
and '"arguments"' in potential_json):
search_text = potential_json
# Try to find complete tool names first
name_pattern = r'"name"\s*:\s*"([^"]+)"'
name_matches = list(re.finditer(name_pattern, current_text))
name_matches = list(re.finditer(name_pattern, search_text))
tool_count = len(name_matches)
# If no tools found yet, return
# If no complete tool names found, check for partial tool names
if tool_count == 0:
return None
# Check if we're in the middle of parsing a tool name
partial_name_pattern = r'"name"\s*:\s*"([^"]*)'
partial_matches = list(
re.finditer(partial_name_pattern, search_text))
if partial_matches:
# We have a partial tool name - not ready to emit yet
return None
else:
# No tools found at all
return None
# Ensure our state arrays are large enough
while len(self.streaming_state["sent_tools"]) < tool_count:
@ -332,7 +382,7 @@ class xLAMToolParser(ToolParser):
# First, check for the empty arguments case: "arguments": {}
empty_args_pattern = (
r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{\s*\}')
empty_args_match = re.search(empty_args_pattern, current_text)
empty_args_match = re.search(empty_args_pattern, search_text)
# Check if this tool has empty arguments
if empty_args_match and empty_args_match.start() > 0:
@ -376,7 +426,7 @@ class xLAMToolParser(ToolParser):
# Extract arguments for current tool using regex for non-empty arguments
args_pattern = r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*(\{(?:[^{}]|(?:\{[^{}]*\}))*\})'
args_matches = list(re.finditer(args_pattern, current_text))
args_matches = list(re.finditer(args_pattern, search_text))
if current_idx < len(args_matches):
args_text = args_matches[current_idx].group(1)
@ -384,17 +434,25 @@ class xLAMToolParser(ToolParser):
# Handle transition between tools
is_last_tool = current_idx == tool_count - 1
# Find where the arguments for our current tool end
if not is_last_tool:
# If we have more tools after this one, try to find the complete argument block
next_tool_pos = current_text.find(
"},{", args_matches[current_idx].start())
if next_tool_pos != -1:
args_end_pos = (next_tool_pos + 1
) # +1 to include the '}'
args_text = (current_text[args_matches[current_idx]
.start():args_end_pos].
split('"arguments":')[1].strip())
# For multiple tools, extract only the arguments for the current tool
if tool_count > 1:
# Parse the entire JSON structure to properly extract arguments for each tool
try:
parsed_tools = json.loads(search_text)
if isinstance(
parsed_tools,
list) and current_idx < len(parsed_tools):
current_tool = parsed_tools[current_idx]
if isinstance(current_tool.get("arguments"),
dict):
args_text = json.dumps(
current_tool["arguments"])
else:
args_text = str(
current_tool.get("arguments", "{}"))
except (json.JSONDecodeError, KeyError, IndexError):
# Fallback to regex-based extraction
pass
# If arguments haven't been sent yet
sent_args = self.streaming_state["sent_tools"][
@ -419,7 +477,7 @@ class xLAMToolParser(ToolParser):
index=current_idx,
function=DeltaFunctionCall(
arguments="{").model_dump(
exclude_none=True), # type: ignore
exclude_none=True), # type: ignore
)
])
return delta