From 422e793fa6c4381809ca236946ab5c2206ad59c1 Mon Sep 17 00:00:00 2001
From: Code Jesus
Date: Sun, 31 Aug 2025 23:07:54 -0700
Subject: [PATCH] [Bugfix] Add support for `<tool_call>` format in streaming
 mode for XLAM Tool Parser (#22769)

Signed-off-by: Devon Peroutky
---
 tests/tool_use/test_xlam_tool_parser.py      | 218 +++++++++++++++++-
 .../openai/tool_parsers/xlam_tool_parser.py  | 102 ++++++--
 2 files changed, 296 insertions(+), 24 deletions(-)

diff --git a/tests/tool_use/test_xlam_tool_parser.py b/tests/tool_use/test_xlam_tool_parser.py
index 8d26b90515901..0bc22e4f1031c 100644
--- a/tests/tool_use/test_xlam_tool_parser.py
+++ b/tests/tool_use/test_xlam_tool_parser.py
@@ -2,12 +2,17 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import json
+from collections.abc import Generator
+from typing import Optional
 
 import pytest
 
-from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              DeltaMessage, FunctionCall,
+                                              ToolCall)
 from vllm.entrypoints.openai.tool_parsers import xLAMToolParser
-from vllm.transformers_utils.tokenizer import get_tokenizer
+from vllm.transformers_utils.detokenizer import detokenize_incrementally
+from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
 
 # Use a common model that is likely to be available
 MODEL = "Salesforce/Llama-xLAM-2-8B-fc-r"
@@ -36,6 +41,56 @@ def assert_tool_calls(actual_tool_calls: list[ToolCall],
         assert actual_tool_call.function == expected_tool_call.function
 
 
+def stream_delta_message_generator(
+    xlam_tool_parser: xLAMToolParser,
+    xlam_tokenizer: AnyTokenizer,
+    model_output: str,
+    request: Optional[ChatCompletionRequest] = None,
+) -> Generator[DeltaMessage, None, None]:
+    all_token_ids = xlam_tokenizer.encode(model_output,
+                                          add_special_tokens=False)
+
+    previous_text = ""
+    previous_tokens = None
+    prefix_offset = 0
+    read_offset = 0
+    for i, delta_token in enumerate(all_token_ids):
+        delta_token_ids = [delta_token]
+        previous_token_ids = all_token_ids[:i]
+        current_token_ids = all_token_ids[:i + 1]
+
+        (new_tokens, delta_text, new_prefix_offset,
+         new_read_offset) = (detokenize_incrementally(
+             tokenizer=xlam_tokenizer,
+             all_input_ids=current_token_ids,
+             prev_tokens=previous_tokens,
+             prefix_offset=prefix_offset,
+             read_offset=read_offset,
+             skip_special_tokens=False,
+             spaces_between_special_tokens=True,
+         ))
+
+        current_text = previous_text + delta_text
+
+        delta_message = xlam_tool_parser.extract_tool_calls_streaming(
+            previous_text,
+            current_text,
+            delta_text,
+            previous_token_ids,
+            current_token_ids,
+            delta_token_ids,
+            request=request,
+        )
+        if delta_message:
+            yield delta_message
+
+        previous_text = current_text
+        previous_tokens = (previous_tokens +
+                           new_tokens if previous_tokens else new_tokens)
+        prefix_offset = new_prefix_offset
+        read_offset = new_read_offset
+
+
 def test_extract_tool_calls_no_tools(xlam_tool_parser):
     model_output = "This is a test"
     extracted_tool_calls = xlam_tool_parser.extract_tool_calls(
@@ -51,6 +106,7 @@ def test_extract_tool_calls_no_tools(xlam_tool_parser):
         "single_tool_with_think_tag",
         "single_tool_with_json_code_block",
         "single_tool_with_tool_calls_tag",
+        "single_tool_with_tool_call_xml_tags",
     ],
     argnames=["model_output", "expected_tool_calls", "expected_content"],
     argvalues=[
@@ -118,6 +174,20 @@ def test_extract_tool_calls_no_tools(xlam_tool_parser):
             ],
             "I'll check the weather for you.",
         ),
+        (
+            """I'll help you check the weather.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""",  # noqa: E501
+            [
+                ToolCall(function=FunctionCall(
+                    name="get_current_weather",
+                    arguments=json.dumps({
+                        "city": "Dallas",
+                        "state": "TX",
+                        "unit": "fahrenheit",
+                    }),
+                ))
+            ],
+            "I'll help you check the weather.",
+        ),
     ],
 )
 def test_extract_tool_calls(xlam_tool_parser, model_output,
@@ -245,3 +315,147 @@ def test_streaming_with_list_structure(xlam_tool_parser):
     assert hasattr(result, "tool_calls")
     assert len(result.tool_calls) == 1
     assert result.tool_calls[0].function.name == "get_current_weather"
+
+
+@pytest.mark.parametrize(
+    ids=[
+        "parallel_tool_calls",
+        "single_tool_with_think_tag",
+        "single_tool_with_json_code_block",
+        "single_tool_with_tool_calls_tag",
+        "single_tool_with_tool_call_xml_tags",
+    ],
+    argnames=["model_output", "expected_tool_calls", "expected_content"],
+    argvalues=[
+        (
+            """[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}, {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}]""",  # noqa: E501
+            [
+                ToolCall(function=FunctionCall(
+                    name="get_current_weather",
+                    arguments=json.dumps({
+                        "city": "Dallas",
+                        "state": "TX",
+                        "unit": "fahrenheit",
+                    }),
+                )),
+                ToolCall(function=FunctionCall(
+                    name="get_current_weather",
+                    arguments=json.dumps({
+                        "city": "Orlando",
+                        "state": "FL",
+                        "unit": "fahrenheit",
+                    }),
+                )),
+            ],
+            "",
+        ),
+        (
+            """I'll help you with that.[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""",  # noqa: E501
+            [
+                ToolCall(function=FunctionCall(
+                    name="get_current_weather",
+                    arguments=json.dumps({
+                        "city": "Dallas",
+                        "state": "TX",
+                        "unit": "fahrenheit",
+                    }),
+                ))
+            ],
+            "I'll help you with that.",
+        ),
+        (
+            """```json\n[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]\n```""",  # noqa: E501
+            [
+                ToolCall(function=FunctionCall(
+                    name="get_current_weather",
+                    arguments=json.dumps({
+                        "city": "Dallas",
+                        "state": "TX",
+                        "unit": "fahrenheit",
+                    }),
+                ))
+            ],
+            "",
+        ),
+        (
+            """[TOOL_CALLS][{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""",  # noqa: E501
+            [
+                ToolCall(function=FunctionCall(
+                    name="get_current_weather",
+                    arguments=json.dumps({
+                        "city": "Dallas",
+                        "state": "TX",
+                        "unit": "fahrenheit",
+                    }),
+                ))
+            ],
+            "",
+        ),
+        (
+            """I can help with that.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""",  # noqa: E501
+            [
+                ToolCall(function=FunctionCall(
+                    name="get_current_weather",
+                    arguments=json.dumps({
+                        "city": "Dallas",
+                        "state": "TX",
+                        "unit": "fahrenheit",
+                    }),
+                ))
+            ],
+            "I can help with that.",
+        ),
+    ],
+)
+def test_extract_tool_calls_streaming_incremental(
+    xlam_tool_parser,
+    xlam_tokenizer,
+    model_output,
+    expected_tool_calls,
+    expected_content,
+):
+    """Verify the xLAM parser's streaming behavior by checking each emitted chunk."""  # noqa: E501
+    request = ChatCompletionRequest(model=MODEL, messages=[], tools=[])
+
+    chunks = []
+    for delta_message in stream_delta_message_generator(
+            xlam_tool_parser, xlam_tokenizer, model_output, request):
+        chunks.append(delta_message)
+
+    # Should have multiple chunks
+    assert len(chunks) >= 3
+
+    # Should have a chunk with tool header (id, name, type) for the first tool call  # noqa: E501
+    header_found = False
+    expected_first_tool = expected_tool_calls[0]
+    for chunk in chunks:
+        if chunk.tool_calls and chunk.tool_calls[0].id:
+            header_found = True
+            assert (chunk.tool_calls[0].function.name ==
+                    expected_first_tool.function.name)
+            assert chunk.tool_calls[0].type == "function"
+            # Arguments may be empty initially or None
+            if chunk.tool_calls[0].function.arguments is not None:
+                # If present, should be empty string initially
+                assert chunk.tool_calls[0].function.arguments == ""
+            break
+    assert header_found
+
+    # Should have chunks with incremental arguments
+    arg_chunks = []
+    for chunk in chunks:
+        if (chunk.tool_calls and chunk.tool_calls[0].function.arguments
+                and chunk.tool_calls[0].function.arguments != ""
+                and chunk.tool_calls[0].index ==
+                0  # Only collect arguments from the first tool call
+            ):
+            arg_chunks.append(chunk.tool_calls[0].function.arguments)
+
+    # Arguments should be streamed incrementally
+    assert len(arg_chunks) > 1
+
+    # Concatenated arguments should form valid JSON for the first tool call
+    full_args = "".join(arg_chunks)
+    parsed_args = json.loads(full_args)
+    expected_args = json.loads(expected_first_tool.function.arguments)
+    assert parsed_args == expected_args
diff --git a/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py
index 87cd413b37200..484e904cd8c36 100644
--- a/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py
@@ -186,11 +186,31 @@ class xLAMToolParser(ToolParser):
         """
         Extract tool calls for streaming mode.
         """
-        # Simplify detection: if it begins with "[" treat it as a function call
-        is_function_call = (current_text.strip().startswith("["))
+        # First, check for a definitive start of a tool call block.
+        # This prevents premature parsing of incomplete output.
+        stripped_text = current_text.strip()
+        preprocessed_content, preprocessed_tool_calls = (
+            self.preprocess_model_output(current_text))
 
-        # If not a function call, return normal content
-        if not is_function_call:
+        # For JSON code blocks, we need to detect them earlier, even if incomplete
+        has_potential_json_block = ("```json" in current_text
+                                    or "```\n[" in current_text
+                                    or "[TOOL_CALLS]" in current_text
+                                    or "<tool_call>" in current_text)
+
+        is_tool_call_block = (
+            stripped_text.startswith("[")
+            or stripped_text.startswith("<tool_call>")
+            or stripped_text.startswith("[TOOL_CALLS]") or
+            # Check if we have thinking tags with JSON-like content following
+            ("</think>[" in current_text) or
+            # Check if the text contains a JSON array after preprocessing
+            preprocessed_tool_calls is not None or
+            # For JSON code blocks, detect early if we see enough structure
+            (has_potential_json_block and '"name"' in current_text
+             and '"arguments"' in current_text))
+
+        if not is_tool_call_block:
             return DeltaMessage(content=delta_text)
 
         try:
@@ -204,7 +224,10 @@ class xLAMToolParser(ToolParser):
 
             # Try parsing as JSON to check for complete tool calls
             try:
-                parsed_tools = json.loads(current_text)
+                # Use preprocessed tool calls if available
+                tool_calls_text = (preprocessed_tool_calls if
+                                   preprocessed_tool_calls else current_text)
+                parsed_tools = json.loads(tool_calls_text)
                 if isinstance(parsed_tools, list):
                     # Update our tool array for next time
                     self.prev_tool_call_arr = parsed_tools
@@ -257,13 +280,40 @@ class xLAMToolParser(ToolParser):
                 return delta
 
             # Use regex to identify tool calls in the output
+            # Use preprocessed tool calls text for better parsing, but also try to extract from incomplete JSON blocks
+            search_text = (preprocessed_tool_calls
+                           if preprocessed_tool_calls else current_text)
+
+            # For JSON code blocks that aren't complete yet, try to extract the JSON content
+            if not preprocessed_tool_calls and has_potential_json_block:
+                # Try to extract the JSON array from within the code block
+                json_match = re.search(r"```(?:json)?\s*([\s\S]*?)(?:```|$)",
+                                       current_text)
+                if json_match:
+                    potential_json = json_match.group(1).strip()
+                    # Use this as search text even if it's incomplete
+                    if potential_json.startswith("[") and (
+                            '"name"' in potential_json
+                            and '"arguments"' in potential_json):
+                        search_text = potential_json
+
+            # Try to find complete tool names first
             name_pattern = r'"name"\s*:\s*"([^"]+)"'
-            name_matches = list(re.finditer(name_pattern, current_text))
+            name_matches = list(re.finditer(name_pattern, search_text))
             tool_count = len(name_matches)
 
-            # If no tools found yet, return
+            # If no complete tool names found, check for partial tool names
             if tool_count == 0:
-                return None
+                # Check if we're in the middle of parsing a tool name
+                partial_name_pattern = r'"name"\s*:\s*"([^"]*)'
+                partial_matches = list(
+                    re.finditer(partial_name_pattern, search_text))
+                if partial_matches:
+                    # We have a partial tool name - not ready to emit yet
+                    return None
+                else:
+                    # No tools found at all
+                    return None
 
             # Ensure our state arrays are large enough
             while len(self.streaming_state["sent_tools"]) < tool_count:
@@ -332,7 +382,7 @@ class xLAMToolParser(ToolParser):
             # First, check for the empty arguments case: "arguments": {}
             empty_args_pattern = (
                 r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{\s*\}')
-            empty_args_match = re.search(empty_args_pattern, current_text)
+            empty_args_match = re.search(empty_args_pattern, search_text)
 
             # Check if this tool has empty arguments
             if empty_args_match and empty_args_match.start() > 0:
@@ -376,7 +426,7 @@ class xLAMToolParser(ToolParser):
 
             # Extract arguments for current tool using regex for non-empty arguments
             args_pattern = r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*(\{(?:[^{}]|(?:\{[^{}]*\}))*\})'
-            args_matches = list(re.finditer(args_pattern, current_text))
+            args_matches = list(re.finditer(args_pattern, search_text))
 
             if current_idx < len(args_matches):
                 args_text = args_matches[current_idx].group(1)
@@ -384,17 +434,25 @@ class xLAMToolParser(ToolParser):
                 # Handle transition between tools
                 is_last_tool = current_idx == tool_count - 1
 
-                # Find where the arguments for our current tool end
-                if not is_last_tool:
-                    # If we have more tools after this one, try to find the complete argument block
-                    next_tool_pos = current_text.find(
-                        "},{", args_matches[current_idx].start())
-                    if next_tool_pos != -1:
-                        args_end_pos = (next_tool_pos + 1
-                                        )  # +1 to include the '}'
-                        args_text = (current_text[args_matches[current_idx]
-                                                  .start():args_end_pos].
-                                     split('"arguments":')[1].strip())
+                # For multiple tools, extract only the arguments for the current tool
+                if tool_count > 1:
+                    # Parse the entire JSON structure to properly extract arguments for each tool
+                    try:
+                        parsed_tools = json.loads(search_text)
+                        if isinstance(
+                                parsed_tools,
+                                list) and current_idx < len(parsed_tools):
+                            current_tool = parsed_tools[current_idx]
+                            if isinstance(current_tool.get("arguments"),
+                                          dict):
+                                args_text = json.dumps(
+                                    current_tool["arguments"])
+                            else:
+                                args_text = str(
+                                    current_tool.get("arguments", "{}"))
+                    except (json.JSONDecodeError, KeyError, IndexError):
+                        # Fallback to regex-based extraction
+                        pass
 
                 # If arguments haven't been sent yet
                 sent_args = self.streaming_state["sent_tools"][
@@ -419,7 +477,7 @@ class xLAMToolParser(ToolParser):
                             index=current_idx,
                             function=DeltaFunctionCall(
                                 arguments="{").model_dump(
-                                    exclude_none=True), # type: ignore
+                                    exclude_none=True),  # type: ignore
                         )
                     ])
                 return delta
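
Usage sketch (not part of the patch): a minimal illustration of the `<tool_call>`-wrapped output exercised by the new test cases, run through the non-streaming entry point. It assumes the same setup the tests above rely on (an `xLAMToolParser` built from the model's tokenizer, `extract_tool_calls()` accepting a `ChatCompletionRequest` and returning an object exposing `content` and `tool_calls`).

# Illustrative sketch only; assumes the fixture-style setup used in
# tests/tool_use/test_xlam_tool_parser.py.
import json

from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.tool_parsers import xLAMToolParser
from vllm.transformers_utils.tokenizer import get_tokenizer

MODEL = "Salesforce/Llama-xLAM-2-8B-fc-r"
parser = xLAMToolParser(get_tokenizer(MODEL))

# Model output wrapped in the <tool_call> tags that this patch also teaches
# the streaming path to recognize as a tool-call block.
model_output = (
    'I can help with that.<tool_call>[{"name": "get_current_weather", '
    '"arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]'
    '</tool_call>')

request = ChatCompletionRequest(model=MODEL, messages=[], tools=[])
result = parser.extract_tool_calls(model_output, request=request)

print(result.content)  # expected: "I can help with that."
for call in result.tool_calls:
    print(call.function.name, json.loads(call.function.arguments))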