diff --git a/tests/tool_use/test_qwen3coder_tool_parser.py b/tests/tool_use/test_qwen3coder_tool_parser.py new file mode 100644 index 0000000000000..40c3158e9e683 --- /dev/null +++ b/tests/tool_use/test_qwen3coder_tool_parser.py @@ -0,0 +1,618 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +from collections.abc import Generator +from typing import Optional + +import pytest + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + ChatCompletionToolsParam, + DeltaMessage, FunctionCall, + ToolCall) +from vllm.entrypoints.openai.tool_parsers.qwen3coder_tool_parser import ( + Qwen3CoderToolParser) +from vllm.transformers_utils.detokenizer import detokenize_incrementally +from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer + +MODEL = "Qwen/Qwen3-Coder-480B-A35B-Instruct-FP8" + + +@pytest.fixture(scope="module") +def qwen3_tokenizer(): + return get_tokenizer(tokenizer_name=MODEL) + + +@pytest.fixture +def qwen3_tool_parser(qwen3_tokenizer): + return Qwen3CoderToolParser(qwen3_tokenizer) + + +@pytest.fixture +def sample_tools(): + return [ + ChatCompletionToolsParam(type="function", + function={ + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city name" + }, + "state": { + "type": "string", + "description": + "The state code" + }, + "unit": { + "type": "string", + "enum": + ["fahrenheit", "celsius"] + } + }, + "required": ["city", "state"] + } + }), + ChatCompletionToolsParam(type="function", + function={ + "name": "calculate_area", + "description": + "Calculate area of a shape", + "parameters": { + "type": "object", + "properties": { + "shape": { + "type": "string" + }, + "dimensions": { + "type": "object" + }, + "precision": { + "type": "integer" + } + } + } + }) + ] + + +def assert_tool_calls(actual_tool_calls: list[ToolCall], + expected_tool_calls: list[ToolCall]): + assert len(actual_tool_calls) == len(expected_tool_calls) + + for actual_tool_call, expected_tool_call in zip(actual_tool_calls, + expected_tool_calls): + # Qwen3 parser doesn't generate IDs during extraction + assert actual_tool_call.type == "function" + assert ( + actual_tool_call.function.name == expected_tool_call.function.name) + assert (json.loads(actual_tool_call.function.arguments) == json.loads( + expected_tool_call.function.arguments)) + + +def stream_delta_message_generator( + qwen3_tool_parser: Qwen3CoderToolParser, + qwen3_tokenizer: AnyTokenizer, + model_output: str, + request: Optional[ChatCompletionRequest] = None +) -> Generator[DeltaMessage, None, None]: + all_token_ids = qwen3_tokenizer.encode(model_output, + add_special_tokens=False) + + previous_text = "" + previous_tokens = None + prefix_offset = 0 + read_offset = 0 + for i, delta_token in enumerate(all_token_ids): + delta_token_ids = [delta_token] + previous_token_ids = all_token_ids[:i] + current_token_ids = all_token_ids[:i + 1] + + (new_tokens, delta_text, new_prefix_offset, + new_read_offset) = detokenize_incrementally( + tokenizer=qwen3_tokenizer, + all_input_ids=current_token_ids, + prev_tokens=previous_tokens, + prefix_offset=prefix_offset, + read_offset=read_offset, + skip_special_tokens=False, + spaces_between_special_tokens=True, + ) + + current_text = previous_text + delta_text + + delta_message = qwen3_tool_parser.extract_tool_calls_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + delta_token_ids, + request=request, + ) + if delta_message: + yield delta_message + + previous_text = current_text + previous_tokens = (previous_tokens + + new_tokens if previous_tokens else new_tokens) + prefix_offset = new_prefix_offset + read_offset = new_read_offset + + +def test_extract_tool_calls_no_tools(qwen3_tool_parser): + model_output = "This is a test response without any tool calls" + extracted_tool_calls = qwen3_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + assert not extracted_tool_calls.tools_called + assert extracted_tool_calls.tool_calls == [] + assert extracted_tool_calls.content == model_output + + +@pytest.mark.parametrize( + ids=[ + "single_tool", + "single_tool_with_content", + "single_tool_multiline_param", + "parallel_tools", + "tool_with_typed_params", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + (''' + + +Dallas + + +TX + + +fahrenheit + + +''', [ + ToolCall( + function=FunctionCall(name="get_current_weather", + arguments=json.dumps({ + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit" + }))) + ], None), + ('''Sure! Let me check the weather for you. + + +Dallas + + +TX + + +fahrenheit + + +''', [ + ToolCall( + function=FunctionCall(name="get_current_weather", + arguments=json.dumps({ + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit" + }))) + ], "Sure! Let me check the weather for you."), + (''' + + +rectangle + + +{"width": 10, + "height": 20} + + +2 + + +''', [ + ToolCall(function=FunctionCall(name="calculate_area", + arguments=json.dumps({ + "shape": "rectangle", + "dimensions": { + "width": 10, + "height": 20 + }, + "precision": 2 + }))) + ], None), + (''' + + +Dallas + + +TX + + +fahrenheit + + + + + + +Orlando + + +FL + + +fahrenheit + + +''', [ + ToolCall( + function=FunctionCall(name="get_current_weather", + arguments=json.dumps({ + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit" + }))), + ToolCall( + function=FunctionCall(name="get_current_weather", + arguments=json.dumps({ + "city": "Orlando", + "state": "FL", + "unit": "fahrenheit" + }))) + ], None), + ('''Let me calculate that area for you. + + +circle + + +{"radius": 15.5} + + +3 + + +''', [ + ToolCall(function=FunctionCall(name="calculate_area", + arguments=json.dumps({ + "shape": "circle", + "dimensions": { + "radius": 15.5 + }, + "precision": 3 + }))) + ], "Let me calculate that area for you."), + ], +) +def test_extract_tool_calls(qwen3_tool_parser, sample_tools, model_output, + expected_tool_calls, expected_content): + request = ChatCompletionRequest(model=MODEL, + messages=[], + tools=sample_tools) + extracted_tool_calls = qwen3_tool_parser.extract_tool_calls( + model_output, request=request) + assert extracted_tool_calls.tools_called + + assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) + + assert extracted_tool_calls.content == expected_content + + +def test_extract_tool_calls_fallback_no_tags(qwen3_tool_parser, sample_tools): + """Test fallback parsing when XML tags are missing""" + model_output = ''' + +Dallas + + +TX + +''' + + request = ChatCompletionRequest(model=MODEL, + messages=[], + tools=sample_tools) + extracted_tool_calls = qwen3_tool_parser.extract_tool_calls( + model_output, request=request) + + assert extracted_tool_calls.tools_called + assert len(extracted_tool_calls.tool_calls) == 1 + assert (extracted_tool_calls.tool_calls[0].function.name == + "get_current_weather") + + +def test_extract_tool_calls_type_conversion(qwen3_tool_parser): + """Test parameter type conversion based on tool schema""" + tools = [ + ChatCompletionToolsParam(type="function", + function={ + "name": "test_types", + "parameters": { + "type": "object", + "properties": { + "int_param": { + "type": "integer" + }, + "float_param": { + "type": "float" + }, + "bool_param": { + "type": "boolean" + }, + "str_param": { + "type": "string" + }, + "obj_param": { + "type": "object" + } + } + } + }) + ] + + model_output = ''' + + +42 + + +3.14 + + +true + + +hello world + + +{"key": "value"} + + +''' + + request = ChatCompletionRequest(model=MODEL, messages=[], tools=tools) + extracted_tool_calls = qwen3_tool_parser.extract_tool_calls( + model_output, request=request) + + args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments) + assert args["int_param"] == 42 + assert args["float_param"] == 3.14 + assert args["bool_param"] is True + assert args["str_param"] == "hello world" + assert args["obj_param"] == {"key": "value"} + + +@pytest.mark.parametrize( + ids=[ + "no_tools", + "single_tool", + "single_tool_with_content", + "parallel_tools", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ("This is a test without tools", [], "This is a test without tools"), + (''' + + +Dallas + + +TX + + +fahrenheit + + +''', [ + ToolCall( + function=FunctionCall(name="get_current_weather", + arguments=json.dumps({ + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit" + }))) + ], ""), + ('''Sure! Let me check the weather for you. + + +Dallas + + +TX + + +fahrenheit + + +''', [ + ToolCall( + function=FunctionCall(name="get_current_weather", + arguments=json.dumps({ + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit" + }))) + ], "Sure! Let me check the weather for you."), + (''' + + +Dallas + + +TX + + +fahrenheit + + + + + + +Orlando + + +FL + + +celsius + + +''', [ + ToolCall( + function=FunctionCall(name="get_current_weather", + arguments=json.dumps({ + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit" + }))), + ToolCall( + function=FunctionCall(name="get_current_weather", + arguments=json.dumps({ + "city": "Orlando", + "state": "FL", + "unit": "celsius" + }))) + ], ""), + ], +) +def test_extract_tool_calls_streaming(qwen3_tool_parser, qwen3_tokenizer, + sample_tools, model_output, + expected_tool_calls, expected_content): + """Test incremental streaming behavior""" + request = ChatCompletionRequest(model=MODEL, + messages=[], + tools=sample_tools) + + other_content = '' + tool_states = {} # Track state per tool index + + for delta_message in stream_delta_message_generator( + qwen3_tool_parser, qwen3_tokenizer, model_output, request): + # role should never be streamed from tool parser + assert not delta_message.role + + if delta_message.content: + other_content += delta_message.content + + if delta_message.tool_calls: + for tool_call in delta_message.tool_calls: + idx = tool_call.index + + # Initialize state for new tool + if idx not in tool_states: + tool_states[idx] = { + "id": None, + "name": None, + "arguments": "", + "type": None + } + + # First chunk should have id, name, and type + if tool_call.id: + tool_states[idx]["id"] = tool_call.id + + if tool_call.type: + assert tool_call.type == "function" + tool_states[idx]["type"] = tool_call.type + + if tool_call.function: + if tool_call.function.name: + # Should only be set once + assert tool_states[idx]["name"] is None + tool_states[idx]["name"] = tool_call.function.name + + if tool_call.function.arguments is not None: + # Accumulate arguments incrementally + tool_states[idx][ + "arguments"] += tool_call.function.arguments + + # Verify final content + assert other_content == expected_content + + # Verify we got all expected tool calls + assert len(tool_states) == len(expected_tool_calls) + + # Verify each tool call + for idx, expected_tool in enumerate(expected_tool_calls): + state = tool_states[idx] + assert state["id"] is not None + assert state["type"] == "function" + assert state["name"] == expected_tool.function.name + + # Parse accumulated arguments + arguments_str = state["arguments"] + assert arguments_str is not None + actual_args = json.loads(arguments_str) + expected_args = json.loads(expected_tool.function.arguments) + assert actual_args == expected_args + + +def test_extract_tool_calls_streaming_incremental(qwen3_tool_parser, + qwen3_tokenizer, + sample_tools): + """Test that streaming is truly incremental""" + model_output = '''I'll check the weather. + + +Dallas + + +TX + + +''' + + request = ChatCompletionRequest(model=MODEL, + messages=[], + tools=sample_tools) + + chunks = [] + for delta_message in stream_delta_message_generator( + qwen3_tool_parser, qwen3_tokenizer, model_output, request): + chunks.append(delta_message) + + # Should have multiple chunks + assert len(chunks) > 3 + + # First chunk(s) should be content + assert chunks[0].content is not None + assert chunks[0].tool_calls is None or chunks[0].tool_calls == [] + + # Should have a chunk with tool header (id, name, type) + header_found = False + for chunk in chunks: + if chunk.tool_calls and chunk.tool_calls[0].id: + header_found = True + assert (chunk.tool_calls[0].function.name == "get_current_weather") + assert chunk.tool_calls[0].type == "function" + # Empty initially + assert chunk.tool_calls[0].function.arguments == "" + break + assert header_found + + # Should have chunks with incremental arguments + arg_chunks = [] + for chunk in chunks: + if chunk.tool_calls and chunk.tool_calls[0].function.arguments: + arg_chunks.append(chunk.tool_calls[0].function.arguments) + + # Arguments should be streamed incrementally + assert len(arg_chunks) > 1 + + # Concatenated arguments should form valid JSON + full_args = "".join(arg_chunks) + parsed_args = json.loads(full_args) + assert parsed_args["city"] == "Dallas" + assert parsed_args["state"] == "TX" diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py index 9eda7155f01f3..88c8aa929b78d 100644 --- a/vllm/entrypoints/openai/tool_parsers/__init__.py +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -17,6 +17,7 @@ from .minimax_tool_parser import MinimaxToolParser from .mistral_tool_parser import MistralToolParser from .phi4mini_tool_parser import Phi4MiniJsonToolParser from .pythonic_tool_parser import PythonicToolParser +from .qwen3coder_tool_parser import Qwen3CoderToolParser from .xlam_tool_parser import xLAMToolParser __all__ = [ @@ -38,4 +39,5 @@ __all__ = [ "KimiK2ToolParser", "HunyuanA13BToolParser", "Glm4MoeModelToolParser", + "Qwen3CoderToolParser", ] diff --git a/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py new file mode 100644 index 0000000000000..cf4d0b231aee1 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py @@ -0,0 +1,669 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +import uuid +from collections.abc import Sequence +from typing import Any, Optional, Union + +import regex as re + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + ChatCompletionToolsParam, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer + +logger = init_logger(__name__) + + +@ToolParserManager.register_module(["qwen3_coder"]) +class Qwen3CoderToolParser(ToolParser): + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + self.current_tool_name_sent: bool = False + self.prev_tool_call_arr: list[dict] = [] + self.streamed_args_for_tool: list[str] = [] + + # Sentinel tokens for streaming mode + self.tool_call_start_token: str = "" + self.tool_call_end_token: str = "" + self.tool_call_prefix: str = "(.*?)", re.DOTALL) + self.tool_call_regex = re.compile( + r"(.*?)|(.*?)$", re.DOTALL) + self.tool_call_function_regex = re.compile( + r"|| str: + """Generate a unique tool call ID.""" + return f"call_{uuid.uuid4().hex[:24]}" + + def _reset_streaming_state(self): + """Reset all streaming state.""" + self.current_tool_index = 0 + self.is_tool_call_started = False + self.header_sent = False + self.current_tool_string_id = None + self.current_function_name = None + self.current_param_name = None + self.current_param_value = "" + self.param_count = 0 + self.in_param = False + self.in_function = False + self.accumulated_text = "" + self.json_started = False + self.json_closed = False + + def _parse_xml_function_call( + self, function_call_str: str, + tools: Optional[list[ChatCompletionToolsParam]] + ) -> Optional[ToolCall]: + + def get_arguments_config(func_name: str) -> dict: + if tools is None: + return {} + for config in tools: + if not hasattr(config, "type") or not ( + hasattr(config, "function") + and hasattr(config.function, "name")): + continue + if (config.type == "function" + and config.function.name == func_name): + if not hasattr(config.function, "parameters"): + return {} + params = config.function.parameters + if isinstance(params, dict) and "properties" in params: + return params["properties"] + elif isinstance(params, dict): + return params + else: + return {} + logger.warning("Tool '%s' is not defined in the tools list.", + func_name) + return {} + + def convert_param_value(param_value: str, param_name: str, + param_config: dict, func_name: str) -> Any: + # Handle null value for any type + if param_value.lower() == "null": + return None + + converted_value: Any + + if param_name not in param_config: + if param_config != {}: + logger.warning( + "Parsed parameter '%s' is not defined in the tool " + "parameters for tool '%s', directly returning the " + "string value.", param_name, func_name) + return param_value + + if (isinstance(param_config[param_name], dict) + and "type" in param_config[param_name]): + param_type = str( + param_config[param_name]["type"]).strip().lower() + else: + param_type = "string" + if param_type in [ + "string", "str", "text", "varchar", "char", "enum" + ]: + return param_value + elif (param_type.startswith("int") or param_type.startswith("uint") + or param_type.startswith("long") + or param_type.startswith("short") + or param_type.startswith("unsigned")): + try: + converted_value = int(param_value) + return converted_value + except ValueError: + logger.warning( + "Parsed value '%s' of parameter '%s' is not an " + "integer in tool '%s', degenerating to string.", + param_value, param_name, func_name) + return param_value + elif (param_type.startswith("num") + or param_type.startswith("float")): + try: + float_param_value = float(param_value) + converted_value = (float_param_value if float_param_value - + int(float_param_value) != 0 else + int(float_param_value)) + return converted_value + except ValueError: + logger.warning( + "Parsed value '%s' of parameter '%s' is not a float " + "in tool '%s', degenerating to string.", param_value, + param_name, func_name) + return param_value + elif param_type in ["boolean", "bool", "binary"]: + param_value = param_value.lower() + if param_value not in ["true", "false"]: + logger.warning( + "Parsed value '%s' of parameter '%s' is not a " + "boolean (`true` of `false`) in tool '%s', " + "degenerating to false.", param_value, param_name, + func_name) + return param_value == "true" + else: + if param_type == "object" or param_type.startswith("dict"): + try: + converted_value = json.loads(param_value) + return converted_value + except json.JSONDecodeError: + logger.warning( + "Parsed value '%s' of parameter '%s' is not a " + "valid JSON object in tool '%s', will try other " + "methods to parse it.", param_value, param_name, + func_name) + try: + converted_value = eval(param_value) + return converted_value + except Exception: + logger.warning( + "Parsed value '%s' of parameter '%s' cannot be " + "converted via Python `eval()` in tool '%s', " + "degenerating to string.", param_value, param_name, + func_name) + return param_value + + # Extract function name + end_index = function_call_str.index(">") + function_name = function_call_str[:end_index] + param_config = get_arguments_config(function_name) + parameters = function_call_str[end_index + 1:] + param_dict = {} + for match in self.tool_call_parameter_regex.findall(parameters): + match_text = match[0] if match[0] else match[1] + idx = match_text.index(">") + param_name = match_text[:idx] + param_value = str(match_text[idx + 1:]) + # Remove prefix and trailing \n + if param_value.startswith("\n"): + param_value = param_value[1:] + if param_value.endswith("\n"): + param_value = param_value[:-1] + + param_dict[param_name] = convert_param_value( + param_value, param_name, param_config, function_name) + return ToolCall( + type="function", + function=FunctionCall(name=function_name, + arguments=json.dumps(param_dict, + ensure_ascii=False)), + ) + + def _get_function_calls(self, model_output: str) -> list[str]: + # Find all tool calls + matched_ranges = self.tool_call_regex.findall(model_output) + raw_tool_calls = [ + match[0] if match[0] else match[1] for match in matched_ranges + ] + + # Back-off strategy if no tool_call tags found + if len(raw_tool_calls) == 0: + raw_tool_calls = [model_output] + + raw_function_calls = [] + for tool_call in raw_tool_calls: + raw_function_calls.extend( + self.tool_call_function_regex.findall(tool_call)) + + function_calls = [ + match[0] if match[0] else match[1] for match in raw_function_calls + ] + return function_calls + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + # Quick check to avoid unnecessary processing + if self.tool_call_prefix not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + try: + function_calls = self._get_function_calls(model_output) + if len(function_calls) == 0: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + tool_calls = [ + self._parse_xml_function_call(function_call_str, request.tools) + for function_call_str in function_calls + ] + + # Populate prev_tool_call_arr for serving layer to set + # finish_reason + self.prev_tool_call_arr.clear() # Clear previous calls + for tool_call in tool_calls: + if tool_call: + self.prev_tool_call_arr.append({ + "name": + tool_call.function.name, + "arguments": + tool_call.function.arguments, + }) + + # Extract content before tool calls + content_index = model_output.find(self.tool_call_start_token) + content_index = (content_index if content_index >= 0 else + model_output.find(self.tool_call_prefix)) + content = model_output[:content_index] # .rstrip() + + return ExtractedToolCallInformation( + tools_called=(len(tool_calls) > 0), + tool_calls=tool_calls, + content=content if content else None, + ) + + except Exception: + logger.exception("Error in extracting tool call from response.") + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + # If no delta text, return None unless it's an EOS token after tool + # calls + if not delta_text: + # Check if this is an EOS token after all tool calls are complete + # We check for tool calls in the text even if is_tool_call_started + # is False because it might have been reset after processing all + # tools + if (delta_token_ids + and self.tool_call_end_token_id not in delta_token_ids): + # Count complete tool calls + complete_calls = len( + self.tool_call_complete_regex.findall(current_text)) + + # If we have completed tool calls and populated + # prev_tool_call_arr + if (complete_calls > 0 and len(self.prev_tool_call_arr) > 0): + # Check if all tool calls are closed + open_calls = ( + current_text.count(self.tool_call_start_token) - + current_text.count(self.tool_call_end_token)) + if open_calls == 0: + # Return empty delta message to allow finish_reason + # processing + return DeltaMessage(content="") + elif not self.is_tool_call_started and current_text: + # This is a regular content response that's now complete + return DeltaMessage(content="") + return None + + # Check if this is the first call (reset state if needed) + if not previous_text: + self._reset_streaming_state() + + # Update accumulated text + self.accumulated_text = current_text + + # Check if we need to advance to next tool + if self.json_closed and not self.in_function: + # Check if this tool call has ended + tool_ends = current_text.count(self.tool_call_end_token) + if tool_ends > self.current_tool_index: + # This tool has ended, advance to next + self.current_tool_index += 1 + self.header_sent = False + self.param_count = 0 + self.json_started = False + self.json_closed = False + + # Check if there are more tool calls + tool_starts_count = current_text.count( + self.tool_call_start_token) + if self.current_tool_index >= tool_starts_count: + # No more tool calls + self.is_tool_call_started = False + # Continue processing next tool + return None + + # Handle normal content before tool calls + if not self.is_tool_call_started: + # Check if tool call is starting + if (self.tool_call_start_token_id in delta_token_ids + or self.tool_call_start_token in delta_text): + self.is_tool_call_started = True + # Return any content before the tool call + if self.tool_call_start_token in delta_text: + content_before = delta_text[:delta_text.index( + self.tool_call_start_token)] + if content_before: + return DeltaMessage(content=content_before) + return None + else: + # Check if we're between tool calls - skip whitespace + if (current_text.rstrip().endswith(self.tool_call_end_token) + and delta_text.strip() == ""): + # We just ended a tool call, skip whitespace + return None + # Normal content, no tool call + return DeltaMessage(content=delta_text) + + # Check if we're between tool calls (waiting for next one) + # Count tool calls we've seen vs processed + tool_starts_count = current_text.count(self.tool_call_start_token) + if self.current_tool_index >= tool_starts_count: + # We're past all tool calls, shouldn't be here + return None + + # We're in a tool call, find the current tool call portion + # Need to find the correct tool call based on current_tool_index + tool_starts: list[int] = [] + idx = 0 + while True: + idx = current_text.find(self.tool_call_start_token, idx) + if idx == -1: + break + tool_starts.append(idx) + idx += len(self.tool_call_start_token) + + if self.current_tool_index >= len(tool_starts): + # No more tool calls to process yet + return None + + tool_start_idx = tool_starts[self.current_tool_index] + # Find where this tool call ends (or current position if not ended yet) + tool_end_idx = current_text.find(self.tool_call_end_token, + tool_start_idx) + if tool_end_idx == -1: + tool_text = current_text[tool_start_idx:] + else: + tool_text = current_text[tool_start_idx:tool_end_idx + + len(self.tool_call_end_token)] + + # Looking for function header + if not self.header_sent: + if self.tool_call_prefix in tool_text: + func_start = (tool_text.find(self.tool_call_prefix) + + len(self.tool_call_prefix)) + func_end = tool_text.find(">", func_start) + + if func_end != -1: + # Found complete function name + self.current_function_name = tool_text[func_start:func_end] + self.current_tool_string_id = self._generate_tool_call_id() + self.header_sent = True + self.in_function = True + + # IMPORTANT: Add to prev_tool_call_arr immediately when we + # detect a tool call. This ensures + # finish_reason="tool_calls" even if parsing isn't complete + already_added = any( + tool.get("name") == self.current_function_name + for tool in self.prev_tool_call_arr) + if not already_added: + self.prev_tool_call_arr.append({ + "name": self.current_function_name, + "arguments": + "{}", # Placeholder, will be updated later + }) + + # Send header with function info + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + id=self.current_tool_string_id, + function=DeltaFunctionCall( + name=self.current_function_name, arguments=""), + type="function", + ) + ]) + return None + + # We've sent header, now handle function body + if self.in_function: + # Send opening brace if not sent yet + if (not self.json_started + and self.parameter_prefix not in delta_text): + self.json_started = True + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall(arguments="{"), + ) + ]) + + # Make sure json_started is set if we're processing parameters + if not self.json_started: + self.json_started = True + + # Check for function end in accumulated text + if not self.json_closed and self.function_end_token in tool_text: + # Close JSON + self.json_closed = True + + # Extract the complete tool call to update prev_tool_call_arr + # with final arguments. Find the function content + func_start = (tool_text.find(self.tool_call_prefix) + + len(self.tool_call_prefix)) + func_content_end = tool_text.find(self.function_end_token, + func_start) + if func_content_end != -1: + func_content = tool_text[func_start:func_content_end] + # Parse to get the complete arguments + try: + parsed_tool = self._parse_xml_function_call( + func_content, request.tools if request else None) + if parsed_tool: + # Update existing entry in prev_tool_call_arr with + # complete arguments + for i, tool in enumerate(self.prev_tool_call_arr): + if (tool.get("name") == + parsed_tool.function.name): + self.prev_tool_call_arr[i]["arguments"] = ( + parsed_tool.function.arguments) + break + except Exception: + pass # Ignore parsing errors during streaming + + result = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall(arguments="}"), + ) + ]) + + # Reset state for next tool + self.in_function = False + self.json_closed = True + + return result + + # Look for parameters + # Count how many complete parameters we have processed + complete_params = tool_text.count(self.parameter_end_token) + + # Check if we should start a new parameter + if not self.in_param and self.param_count < complete_params: + # Find the unprocessed parameter + # Count parameter starts + param_starts = [] + idx = 0 + while True: + idx = tool_text.find(self.parameter_prefix, idx) + if idx == -1: + break + param_starts.append(idx) + idx += len(self.parameter_prefix) + + if len(param_starts) > self.param_count: + # Process the next parameter + param_idx = param_starts[self.param_count] + param_start = param_idx + len(self.parameter_prefix) + remaining = tool_text[param_start:] + + if ">" in remaining: + # We have the complete parameter name + name_end = remaining.find(">") + self.current_param_name = remaining[:name_end] + + # Find the parameter value + value_start = param_start + name_end + 1 + value_text = tool_text[value_start:] + if value_text.startswith("\n"): + value_text = value_text[1:] + + # Find where this parameter ends + param_end_idx = value_text.find( + self.parameter_end_token) + if param_end_idx != -1: + # Complete parameter found + param_value = value_text[:param_end_idx] + if param_value.endswith("\n"): + param_value = param_value[:-1] + + # Build complete JSON fragment for this parameter + if self.param_count == 0: + json_fragment = ( + '"' + self.current_param_name + '": "' + + json.dumps(param_value)[1:-1] + '"') + else: + json_fragment = ( + ', "' + self.current_param_name + '": "' + + json.dumps(param_value)[1:-1] + '"') + + self.param_count += 1 + + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall( + arguments=json_fragment), + ) + ]) + + # Continue parameter value + if self.in_param: + if self.parameter_end_token in delta_text: + # End of parameter + end_idx = delta_text.find(self.parameter_end_token) + value_chunk = delta_text[:end_idx] + + # Skip past > if at start + if not self.current_param_value and ">" in value_chunk: + gt_idx = value_chunk.find(">") + value_chunk = value_chunk[gt_idx + 1:] + + if (not self.current_param_value + and value_chunk.startswith("\n")): + value_chunk = value_chunk[1:] + + # Calculate incremental JSON + full_value = self.current_param_value + value_chunk + prev_escaped = (json.dumps(self.current_param_value)[1:-1] + if self.current_param_value else "") + full_escaped = json.dumps(full_value)[1:-1] + delta_escaped = full_escaped[len(prev_escaped):] + + self.in_param = False + self.current_param_value = "" + + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall( + arguments=delta_escaped + '"'), + ) + ]) + else: + # Continue accumulating value + value_chunk = delta_text + + # Handle first chunk after param name + if not self.current_param_value and ">" in value_chunk: + gt_idx = value_chunk.find(">") + value_chunk = value_chunk[gt_idx + 1:] + + if (not self.current_param_value + and value_chunk.startswith("\n")): + value_chunk = value_chunk[1:] + + if value_chunk: + # Stream the escaped delta + prev_escaped = (json.dumps( + self.current_param_value)[1:-1] + if self.current_param_value else "") + self.current_param_value += value_chunk + full_escaped = json.dumps( + self.current_param_value)[1:-1] + delta_escaped = full_escaped[len(prev_escaped):] + + if delta_escaped: + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall( + arguments=delta_escaped), + ) + ]) + + return None