diff --git a/tests/tool_use/test_xlam_tool_parser.py b/tests/tool_use/test_xlam_tool_parser.py
index 8d26b90515901..0bc22e4f1031c 100644
--- a/tests/tool_use/test_xlam_tool_parser.py
+++ b/tests/tool_use/test_xlam_tool_parser.py
@@ -2,12 +2,17 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
+from collections.abc import Generator
+from typing import Optional
import pytest
-from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+ DeltaMessage, FunctionCall,
+ ToolCall)
from vllm.entrypoints.openai.tool_parsers import xLAMToolParser
-from vllm.transformers_utils.tokenizer import get_tokenizer
+from vllm.transformers_utils.detokenizer import detokenize_incrementally
+from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
# Use a common model that is likely to be available
MODEL = "Salesforce/Llama-xLAM-2-8B-fc-r"
@@ -36,6 +41,56 @@ def assert_tool_calls(actual_tool_calls: list[ToolCall],
assert actual_tool_call.function == expected_tool_call.function
+def stream_delta_message_generator(
+ xlam_tool_parser: xLAMToolParser,
+ xlam_tokenizer: AnyTokenizer,
+ model_output: str,
+ request: Optional[ChatCompletionRequest] = None,
+) -> Generator[DeltaMessage, None, None]:
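+    """Stream model_output through the parser one token at a time.
+
+    Yields every DeltaMessage the parser emits, mimicking how the serving
+    layer feeds incrementally detokenized text to the tool parser.
+    """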
+ all_token_ids = xlam_tokenizer.encode(model_output,
+ add_special_tokens=False)
+
+ previous_text = ""
+ previous_tokens = None
+ prefix_offset = 0
+ read_offset = 0
+ for i, delta_token in enumerate(all_token_ids):
+ delta_token_ids = [delta_token]
+ previous_token_ids = all_token_ids[:i]
+ current_token_ids = all_token_ids[:i + 1]
+
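+        # Detokenize incrementally so delta_text matches what the parser
+        # would see during real streaming.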
+ (new_tokens, delta_text, new_prefix_offset,
+ new_read_offset) = (detokenize_incrementally(
+ tokenizer=xlam_tokenizer,
+ all_input_ids=current_token_ids,
+ prev_tokens=previous_tokens,
+ prefix_offset=prefix_offset,
+ read_offset=read_offset,
+ skip_special_tokens=False,
+ spaces_between_special_tokens=True,
+ ))
+
+ current_text = previous_text + delta_text
+
+ delta_message = xlam_tool_parser.extract_tool_calls_streaming(
+ previous_text,
+ current_text,
+ delta_text,
+ previous_token_ids,
+ current_token_ids,
+ delta_token_ids,
+ request=request,
+ )
+ if delta_message:
+ yield delta_message
+
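+        # Carry text and detokenizer state forward for the next token.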
+ previous_text = current_text
+ previous_tokens = (previous_tokens +
+ new_tokens if previous_tokens else new_tokens)
+ prefix_offset = new_prefix_offset
+ read_offset = new_read_offset
+
+
def test_extract_tool_calls_no_tools(xlam_tool_parser):
model_output = "This is a test"
extracted_tool_calls = xlam_tool_parser.extract_tool_calls(
@@ -51,6 +106,7 @@ def test_extract_tool_calls_no_tools(xlam_tool_parser):
"single_tool_with_think_tag",
"single_tool_with_json_code_block",
"single_tool_with_tool_calls_tag",
+ "single_tool_with_tool_call_xml_tags",
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
@@ -118,6 +174,20 @@ def test_extract_tool_calls_no_tools(xlam_tool_parser):
],
"I'll check the weather for you.",
),
+ (
+            """I'll help you check the weather.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""",  # noqa: E501
+ [
+ ToolCall(function=FunctionCall(
+ name="get_current_weather",
+ arguments=json.dumps({
+ "city": "Dallas",
+ "state": "TX",
+ "unit": "fahrenheit",
+ }),
+ ))
+ ],
+ "I'll help you check the weather.",
+ ),
],
)
def test_extract_tool_calls(xlam_tool_parser, model_output,
@@ -245,3 +315,147 @@ def test_streaming_with_list_structure(xlam_tool_parser):
assert hasattr(result, "tool_calls")
assert len(result.tool_calls) == 1
assert result.tool_calls[0].function.name == "get_current_weather"
+
+
+@pytest.mark.parametrize(
+ ids=[
+ "parallel_tool_calls",
+ "single_tool_with_think_tag",
+ "single_tool_with_json_code_block",
+ "single_tool_with_tool_calls_tag",
+ "single_tool_with_tool_call_xml_tags",
+ ],
+ argnames=["model_output", "expected_tool_calls", "expected_content"],
+ argvalues=[
+ (
+ """[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}, {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}]""", # noqa: E501
+ [
+ ToolCall(function=FunctionCall(
+ name="get_current_weather",
+ arguments=json.dumps({
+ "city": "Dallas",
+ "state": "TX",
+ "unit": "fahrenheit",
+ }),
+ )),
+ ToolCall(function=FunctionCall(
+ name="get_current_weather",
+ arguments=json.dumps({
+ "city": "Orlando",
+ "state": "FL",
+ "unit": "fahrenheit",
+ }),
+ )),
+ ],
+ "",
+ ),
+ (
+ """I'll help you with that.[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""", # noqa: E501
+ [
+ ToolCall(function=FunctionCall(
+ name="get_current_weather",
+ arguments=json.dumps({
+ "city": "Dallas",
+ "state": "TX",
+ "unit": "fahrenheit",
+ }),
+ ))
+ ],
+ "I'll help you with that.",
+ ),
+ (
+ """```json\n[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]\n```""", # noqa: E501
+ [
+ ToolCall(function=FunctionCall(
+ name="get_current_weather",
+ arguments=json.dumps({
+ "city": "Dallas",
+ "state": "TX",
+ "unit": "fahrenheit",
+ }),
+ ))
+ ],
+ "",
+ ),
+ (
+ """[TOOL_CALLS][{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""", # noqa: E501
+ [
+ ToolCall(function=FunctionCall(
+ name="get_current_weather",
+ arguments=json.dumps({
+ "city": "Dallas",
+ "state": "TX",
+ "unit": "fahrenheit",
+ }),
+ ))
+ ],
+ "",
+ ),
+ (
+            """I can help with that.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""",  # noqa: E501
+ [
+ ToolCall(function=FunctionCall(
+ name="get_current_weather",
+ arguments=json.dumps({
+ "city": "Dallas",
+ "state": "TX",
+ "unit": "fahrenheit",
+ }),
+ ))
+ ],
+ "I can help with that.",
+ ),
+ ],
+)
+def test_extract_tool_calls_streaming_incremental(
+ xlam_tool_parser,
+ xlam_tokenizer,
+ model_output,
+ expected_tool_calls,
+ expected_content,
+):
+    """Verify the xLAM parser's streaming behavior chunk by chunk."""
+ request = ChatCompletionRequest(model=MODEL, messages=[], tools=[])
+
+ chunks = []
+ for delta_message in stream_delta_message_generator(
+ xlam_tool_parser, xlam_tokenizer, model_output, request):
+ chunks.append(delta_message)
+
+ # Should have multiple chunks
+ assert len(chunks) >= 3
+
+ # Should have a chunk with tool header (id, name, type) for the first tool call # noqa: E501
+ header_found = False
+ expected_first_tool = expected_tool_calls[0]
+ for chunk in chunks:
+ if chunk.tool_calls and chunk.tool_calls[0].id:
+ header_found = True
+ assert (chunk.tool_calls[0].function.name ==
+ expected_first_tool.function.name)
+ assert chunk.tool_calls[0].type == "function"
+ # Arguments may be empty initially or None
+ if chunk.tool_calls[0].function.arguments is not None:
+ # If present, should be empty string initially
+ assert chunk.tool_calls[0].function.arguments == ""
+ break
+ assert header_found
+
+ # Should have chunks with incremental arguments
+ arg_chunks = []
+ for chunk in chunks:
+        # Only collect arguments from the first tool call
+        if (chunk.tool_calls and chunk.tool_calls[0].index == 0
+                and chunk.tool_calls[0].function.arguments):
+ arg_chunks.append(chunk.tool_calls[0].function.arguments)
+
+ # Arguments should be streamed incrementally
+ assert len(arg_chunks) > 1
+
+ # Concatenated arguments should form valid JSON for the first tool call
+ full_args = "".join(arg_chunks)
+ parsed_args = json.loads(full_args)
+ expected_args = json.loads(expected_first_tool.function.arguments)
+ assert parsed_args == expected_args
diff --git a/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py
index 87cd413b37200..484e904cd8c36 100644
--- a/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py
@@ -186,11 +186,31 @@ class xLAMToolParser(ToolParser):
"""
Extract tool calls for streaming mode.
"""
- # Simplify detection: if it begins with "[" treat it as a function call
- is_function_call = (current_text.strip().startswith("["))
+ # First, check for a definitive start of a tool call block.
+ # This prevents premature parsing of incomplete output.
+ stripped_text = current_text.strip()
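+        # preprocess_model_output() splits the raw output into plain content
+        # and, when present, the JSON text of the tool calls.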
+ preprocessed_content, preprocessed_tool_calls = (
+ self.preprocess_model_output(current_text))
- # If not a function call, return normal content
- if not is_function_call:
+ # For JSON code blocks, we need to detect them earlier, even if incomplete
+ has_potential_json_block = ("```json" in current_text
+ or "```\n[" in current_text
+ or "[TOOL_CALLS]" in current_text
+                                    or "<tool_call>" in current_text)
+
+ is_tool_call_block = (
+ stripped_text.startswith("[")
+            or stripped_text.startswith("<tool_call>")
+ or stripped_text.startswith("[TOOL_CALLS]") or
+ # Check if we have thinking tags with JSON-like content following
+            ("</think>" in current_text and "[" in current_text) or
+ # Check if the text contains a JSON array after preprocessing
+ preprocessed_tool_calls is not None or
+ # For JSON code blocks, detect early if we see enough structure
+ (has_potential_json_block and '"name"' in current_text
+ and '"arguments"' in current_text))
+
+ if not is_tool_call_block:
return DeltaMessage(content=delta_text)
try:
@@ -204,7 +224,10 @@ class xLAMToolParser(ToolParser):
# Try parsing as JSON to check for complete tool calls
try:
- parsed_tools = json.loads(current_text)
+ # Use preprocessed tool calls if available
+ tool_calls_text = (preprocessed_tool_calls if
+ preprocessed_tool_calls else current_text)
+ parsed_tools = json.loads(tool_calls_text)
if isinstance(parsed_tools, list):
# Update our tool array for next time
self.prev_tool_call_arr = parsed_tools
@@ -257,13 +280,40 @@ class xLAMToolParser(ToolParser):
return delta
# Use regex to identify tool calls in the output
+ # Use preprocessed tool calls text for better parsing, but also try to extract from incomplete JSON blocks
+ search_text = (preprocessed_tool_calls
+ if preprocessed_tool_calls else current_text)
+
+ # For JSON code blocks that aren't complete yet, try to extract the JSON content
+ if not preprocessed_tool_calls and has_potential_json_block:
+ # Try to extract the JSON array from within the code block
+ json_match = re.search(r"```(?:json)?\s*([\s\S]*?)(?:```|$)",
+ current_text)
+ if json_match:
+ potential_json = json_match.group(1).strip()
+ # Use this as search text even if it's incomplete
+ if potential_json.startswith("[") and (
+ '"name"' in potential_json
+ and '"arguments"' in potential_json):
+ search_text = potential_json
+
+ # Try to find complete tool names first
name_pattern = r'"name"\s*:\s*"([^"]+)"'
- name_matches = list(re.finditer(name_pattern, current_text))
+ name_matches = list(re.finditer(name_pattern, search_text))
tool_count = len(name_matches)
- # If no tools found yet, return
+ # If no complete tool names found, check for partial tool names
if tool_count == 0:
- return None
+ # Check if we're in the middle of parsing a tool name
+ partial_name_pattern = r'"name"\s*:\s*"([^"]*)'
+ partial_matches = list(
+ re.finditer(partial_name_pattern, search_text))
+ if partial_matches:
+ # We have a partial tool name - not ready to emit yet
+ return None
+ else:
+ # No tools found at all
+ return None
# Ensure our state arrays are large enough
while len(self.streaming_state["sent_tools"]) < tool_count:
@@ -332,7 +382,7 @@ class xLAMToolParser(ToolParser):
# First, check for the empty arguments case: "arguments": {}
empty_args_pattern = (
r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{\s*\}')
- empty_args_match = re.search(empty_args_pattern, current_text)
+ empty_args_match = re.search(empty_args_pattern, search_text)
# Check if this tool has empty arguments
if empty_args_match and empty_args_match.start() > 0:
@@ -376,7 +426,7 @@ class xLAMToolParser(ToolParser):
# Extract arguments for current tool using regex for non-empty arguments
args_pattern = r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*(\{(?:[^{}]|(?:\{[^{}]*\}))*\})'
- args_matches = list(re.finditer(args_pattern, current_text))
+ args_matches = list(re.finditer(args_pattern, search_text))
if current_idx < len(args_matches):
args_text = args_matches[current_idx].group(1)
@@ -384,17 +434,25 @@ class xLAMToolParser(ToolParser):
# Handle transition between tools
is_last_tool = current_idx == tool_count - 1
- # Find where the arguments for our current tool end
- if not is_last_tool:
- # If we have more tools after this one, try to find the complete argument block
- next_tool_pos = current_text.find(
- "},{", args_matches[current_idx].start())
- if next_tool_pos != -1:
- args_end_pos = (next_tool_pos + 1
- ) # +1 to include the '}'
- args_text = (current_text[args_matches[current_idx]
- .start():args_end_pos].
- split('"arguments":')[1].strip())
+ # For multiple tools, extract only the arguments for the current tool
+ if tool_count > 1:
+ # Parse the entire JSON structure to properly extract arguments for each tool
+ try:
+ parsed_tools = json.loads(search_text)
+ if isinstance(
+ parsed_tools,
+ list) and current_idx < len(parsed_tools):
+ current_tool = parsed_tools[current_idx]
+ if isinstance(current_tool.get("arguments"),
+ dict):
+ args_text = json.dumps(
+ current_tool["arguments"])
+ else:
+ args_text = str(
+ current_tool.get("arguments", "{}"))
+ except (json.JSONDecodeError, KeyError, IndexError):
+ # Fallback to regex-based extraction
+ pass
# If arguments haven't been sent yet
sent_args = self.streaming_state["sent_tools"][
@@ -419,7 +477,7 @@ class xLAMToolParser(ToolParser):
index=current_idx,
function=DeltaFunctionCall(
arguments="{").model_dump(
- exclude_none=True), # type: ignore
+ exclude_none=True), # type: ignore
)
])
return delta