diff --git a/requirements/common.txt b/requirements/common.txt
index 8b9e6b935bd20..f18560b98d16c 100644
--- a/requirements/common.txt
+++ b/requirements/common.txt
@@ -46,6 +46,7 @@ scipy # Required for phi-4-multimodal-instruct
 ninja # Required for xgrammar, rocm, tpu, xpu
 pybase64 # fast base64 implementation
 cbor2 # Required for cross-language serialization of hashable objects
+ijson # Required for mistral streaming tool parser
 setproctitle # Used to set process names for better debugging and monitoring
 openai-harmony >= 0.0.3 # Required for gpt-oss
 anthropic == 0.71.0
diff --git a/tests/tool_use/test_mistral_tool_parser.py b/tests/tool_use/test_mistral_tool_parser.py
new file mode 100644
index 0000000000000..e5deb7f40eb35
--- /dev/null
+++ b/tests/tool_use/test_mistral_tool_parser.py
@@ -0,0 +1,847 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import json
+from collections.abc import Generator
+
+import partial_json_parser
+import pytest
+from mistral_common.protocol.instruct.messages import AssistantMessage
+from mistral_common.protocol.instruct.request import InstructRequest
+from mistral_common.protocol.instruct.tool_calls import FunctionCall, ToolCall
+from partial_json_parser.core.options import Allow
+
+from vllm.entrypoints.openai.protocol import DeltaMessage, DeltaToolCall
+from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolParser
+from vllm.tokenizers import (
+    MistralTokenizer,
+    TokenizerLike,
+    get_tokenizer,
+)
+from vllm.tokenizers.detokenizer_utils import detokenize_incrementally
+
+
+@pytest.fixture(scope="module")
+def mistral_pre_v11_tokenizer():
+    MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
+    return get_tokenizer(tokenizer_name=MODEL)
+
+
+@pytest.fixture(scope="module")
+def mistral_tokenizer():
+    MODEL = "mistralai/Mistral-Small-3.2-24B-Instruct-2506"
+    return get_tokenizer(tokenizer_name=MODEL, tokenizer_mode="mistral")
+
+
+@pytest.fixture
+def mistral_pre_v11_tool_parser(mistral_pre_v11_tokenizer):
+    return MistralToolParser(mistral_pre_v11_tokenizer)
+
+
+@pytest.fixture
+def mistral_tool_parser(mistral_tokenizer):
+    return MistralToolParser(mistral_tokenizer)
+
+
+def assert_tool_calls(
+    actual_tool_calls: list[ToolCall] | list[DeltaToolCall],
+    expected_tool_calls: list[ToolCall],
+):
+    assert len(actual_tool_calls) == len(expected_tool_calls)
+
+    for actual_tool_call, expected_tool_call in zip(
+        actual_tool_calls, expected_tool_calls
+    ):
+        assert isinstance(actual_tool_call.id, str)
+        assert len(actual_tool_call.id) == 9
+
+        if isinstance(actual_tool_call, ToolCall):
+            assert actual_tool_call.type == "function"
+        elif isinstance(actual_tool_call, DeltaToolCall):
+            assert actual_tool_call.function is not None
+            assert actual_tool_call.function.name is not None
+            assert actual_tool_call.function.arguments is not None
+        assert actual_tool_call.function is not None
+        assert actual_tool_call.function.name == expected_tool_call.function.name, (
+            f"got wrong function name: {actual_tool_call.function.name}"
+        )
+        assert (
+            actual_tool_call.function.arguments == expected_tool_call.function.arguments
+        ), f"got wrong function arguments: {actual_tool_call.function.arguments}"
+
+
+def fix_tool_call_tokenization(
+    tokens: list[int],
+    mistral_tool_parser: MistralToolParser,
+    mistral_tokenizer: TokenizerLike,
+):
+    """
+    Replaces the textual token sequence for [TOOL_CALLS]
+    with its single special token ID.
+    """
+    textual_tool_call_token_ids = mistral_tokenizer.encode(
+        text=mistral_tool_parser.bot_token,
+        add_special_tokens=False,
+    )
+    # textual_tool_call_token_ids must not contain special tokens like bos, eos etc.
+    special_tool_call_token_ids = [mistral_tool_parser.bot_token_id]
+
+    # If the input is too short to contain the sequence, no replacement is possible
+    if not tokens or len(tokens) < len(textual_tool_call_token_ids):
+        return tokens
+
+    result_tokens = []
+    i = 0
+    target_len = len(textual_tool_call_token_ids)
+
+    while i < len(tokens):
+        # Check if the slice from the current position matches the target sequence
+        if tokens[i : i + target_len] == textual_tool_call_token_ids:
+            # If it matches, add the replacement and jump the index forward
+            result_tokens.extend(special_tool_call_token_ids)
+            i += target_len
+        else:
+            # Otherwise, just add the current token and move to the next one
+            result_tokens.append(tokens[i])
+            i += 1
+
+    return result_tokens
+
+
+def stream_delta_message_generator(
+    mistral_tool_parser: MistralToolParser,
+    mistral_tokenizer: TokenizerLike,
+    model_output: str | None,
+    tools: list[tuple[str, str]] | None,
+) -> Generator[DeltaMessage, None, None]:
+    if (
+        isinstance(mistral_tokenizer, MistralTokenizer)
+        and mistral_tokenizer.version >= 11
+    ):
+        # With the newer versions of the tokenizer,
+        # we cannot tokenize free text,
+        # so we need to create a list of messages to get tokenized
+        assert tools is not None
+        assistant_msg = AssistantMessage(
+            tool_calls=[
+                ToolCall(
+                    function=FunctionCall(
+                        name=name,
+                        arguments=arg,
+                    )
+                )
+                for (name, arg) in tools
+            ],
+        )
+        request = InstructRequest(
+            messages=[assistant_msg],
+        )
+        all_token_ids = mistral_tokenizer.instruct.encode_instruct(request).tokens
+    else:
+        # Older versions of the tokenizer are able to encode
+        # the model's output (free text) directly into tokens
+        assert model_output is not None
+        all_token_ids = mistral_tokenizer.encode(model_output, add_special_tokens=False)
+
+    all_token_ids = fix_tool_call_tokenization(
+        all_token_ids, mistral_tool_parser, mistral_tokenizer
+    )
+
+    previous_text = ""
+    previous_tokens = None
+    prefix_offset = 0
+    read_offset = 0
+    for i, delta_token in enumerate(all_token_ids):
+        delta_token_ids = [delta_token]
+        previous_token_ids = all_token_ids[:i]
+        current_token_ids = all_token_ids[: i + 1]
+
+        (new_tokens, delta_text, new_prefix_offset, new_read_offset) = (
+            detokenize_incrementally(
+                tokenizer=mistral_tokenizer,
+                all_input_ids=current_token_ids,
+                prev_tokens=previous_tokens,
+                prefix_offset=prefix_offset,
+                read_offset=read_offset,
+                skip_special_tokens=isinstance(mistral_tokenizer, MistralTokenizer),
+                spaces_between_special_tokens=True,
+            )
+        )
+
+        current_text = previous_text + delta_text
+
+        delta_message = mistral_tool_parser.extract_tool_calls_streaming(
+            previous_text,
+            current_text,
+            delta_text,
+            previous_token_ids,
+            current_token_ids,
+            delta_token_ids,
+            request=None,  # type: ignore[arg-type]
+        )
+        if delta_message:
+            yield delta_message
+
+        previous_text = current_text
+        previous_tokens = (
+            previous_tokens + new_tokens if previous_tokens else new_tokens
+        )
+        prefix_offset = new_prefix_offset
+        read_offset = new_read_offset
+
+
+def test_extract_tool_calls_no_tools(mistral_pre_v11_tool_parser):
+    model_output = "This is a test"
+    extracted_tool_calls = mistral_pre_v11_tool_parser.extract_tool_calls(
+        model_output, request=None
+    )  # type: ignore[arg-type]
+    assert not extracted_tool_calls.tools_called
+    assert extracted_tool_calls.tool_calls == []
+    assert extracted_tool_calls.content == model_output
+
+
+@pytest.mark.parametrize(
+    ids=[
+        "single_tool_add",
+        "single_tool_weather",
+        "argument_before_name",
+        "argument_before_name_and_name_in_argument",
+    ],
+    argnames=["model_output", "expected_tool_calls", "expected_content"],
+    argvalues=[
+        (
+            """[TOOL_CALLS][{"name": "add", "arguments":{"a": 3.5, "b": 4}}]""",  # noqa: E501
+            [
+                ToolCall(
+                    function=FunctionCall(
+                        name="add", arguments=json.dumps({"a": 3.5, "b": 4})
+                    )
+                )
+            ],
+            None,
+        ),
+        (
+            """[TOOL_CALLS] [{"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""",  # noqa: E501
+            [
+                ToolCall(
+                    function=FunctionCall(
+                        name="get_current_weather",
+                        arguments=json.dumps(
+                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
+                        ),
+                    )
+                )
+            ],
+            None,
+        ),
+        (
+            """[TOOL_CALLS] [{"arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""",  # noqa: E501
+            [
+                ToolCall(
+                    function=FunctionCall(
+                        name="get_current_weather",
+                        arguments=json.dumps(
+                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
+                        ),
+                    )
+                )
+            ],
+            None,
+        ),
+        (
+            """[TOOL_CALLS] [{"arguments":{"name": "John Doe"}, "name": "get_age"}]""",  # noqa: E501
+            [
+                ToolCall(
+                    function=FunctionCall(
+                        name="get_age",
+                        arguments=json.dumps(
+                            {
+                                "name": "John Doe",
+                            }
+                        ),
+                    )
+                )
+            ],
+            None,
+        ),
+    ],
+)
+def test_extract_tool_calls_pre_v11_tokenizer(
+    mistral_pre_v11_tool_parser, model_output, expected_tool_calls, expected_content
+):
+    extracted_tool_calls = mistral_pre_v11_tool_parser.extract_tool_calls(
+        model_output, request=None
+    )  # type: ignore[arg-type]
+    assert extracted_tool_calls.tools_called
+
+    assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
+
+    assert extracted_tool_calls.content == expected_content
+
+
+@pytest.mark.parametrize(
+    ids=[
+        "single_tool_add",
+        "single_tool_weather",
+        "multiple_tool_calls",
+    ],
+    argnames=["model_output", "expected_tool_calls", "expected_content"],
+    argvalues=[
+        (
+            """[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}""",  # noqa: E501
+            [
+                ToolCall(
+                    function=FunctionCall(
+                        name="add_this_and_that",
+                        arguments=json.dumps({"a": 3.5, "b": 4}),
+                    )
+                )
+            ],
+            None,
+        ),
+        (
+            """[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"}""",  # noqa: E501
+            [
+                ToolCall(
+                    function=FunctionCall(
+                        name="get_current_weather",
+                        arguments=json.dumps(
+                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
+                        ),
+                    )
+                )
+            ],
+            None,
+        ),
+        (
+            """[TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]multiply{"a": 3, "b": 6}""",  # noqa: E501
+            [
+                ToolCall(
+                    function=FunctionCall(
+                        name="add", arguments=json.dumps({"a": 3.5, "b": 4})
+                    )
+                ),
+                ToolCall(
+                    function=FunctionCall(
+                        name="multiply", arguments=json.dumps({"a": 3, "b": 6})
+                    )
+                ),
+            ],
+            None,
+        ),
+    ],
+)
+def test_extract_tool_calls(
+    mistral_tool_parser, model_output, expected_tool_calls, expected_content
+):
+    extracted_tool_calls = mistral_tool_parser.extract_tool_calls(
+        model_output, request=None
+    )  # type: ignore[arg-type]
+    assert extracted_tool_calls.tools_called
+
+    assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
+
+    assert extracted_tool_calls.content == expected_content
+
+
+def _test_extract_tool_calls_streaming(
+    tool_parser, tokenizer, model_output, tools, expected_tool_calls, expected_content
+):
+    other_content: str = ""
+    function_names: list[str] = []
+    function_args_strs: list[str] = []
+    tool_call_idx: int = -1
+    tool_call_ids: list[str | None] = []
+
+    for delta_message in stream_delta_message_generator(
+        tool_parser, tokenizer, model_output, tools
+    ):
+        # role should never be streamed from tool parser
+        assert not delta_message.role
+
+        if delta_message.content:
+            other_content += delta_message.content
+
+        streamed_tool_calls = delta_message.tool_calls
+
+        if streamed_tool_calls and len(streamed_tool_calls) > 0:
+            # make sure only one diff is present - correct even for parallel
+            assert len(streamed_tool_calls) == 1
+            tool_call = streamed_tool_calls[0]
+
+            assert len(tool_parser.prev_tool_call_arr) > 0
+
+            # if a new tool is being called, set up empty arguments
+            if tool_call.index != tool_call_idx:
+                tool_call_idx = tool_call.index
+                function_args_strs.append("")
+                tool_call_ids.append(None)
+
+            # if a tool call ID is streamed, make sure one hasn't been already
+            if tool_call.id and not tool_call_ids[tool_call.index]:
+                tool_call_ids[tool_call.index] = tool_call.id
+
+            # if parts of the function start being streamed
+            if tool_call.function:
+                # if the function name is defined, set it. it should be streamed
+                # IN ENTIRETY, exactly one time.
+                if tool_call.function.name:
+                    assert isinstance(tool_call.function.name, str)
+                    function_names.append(tool_call.function.name)
+
+                if tool_call.function.arguments:
+                    # make sure they're a string and then add them to the list
+                    assert isinstance(tool_call.function.arguments, str)
+
+                    function_args_strs[tool_call.index] += tool_call.function.arguments
+
+    assert other_content == expected_content
+
+    actual_tool_calls = [
+        ToolCall(
+            id=tool_call_id,
+            function=FunctionCall(
+                name=function_name,
+                arguments=partial_json_parser.ensure_json(
+                    function_args_str, Allow.OBJ | Allow.STR
+                ),
+            ),
+        )
+        for tool_call_id, function_name, function_args_str in zip(
+            tool_call_ids, function_names, function_args_strs
+        )
+    ]
+    assert_tool_calls(actual_tool_calls, expected_tool_calls)
+
+
+@pytest.mark.parametrize(
+    ids=[
+        "no_tools",
+        "single_tool_add",
+        "single_tool_add_strings",
+        "single_tool_weather",
+        "argument_before_name",
+        "argument_before_name_and_name_in_argument",
+        "multiple_tools",
+    ],
+    argnames=["model_output", "expected_tool_calls", "expected_content"],
+    argvalues=[
+        ("""This is a test""", [], """This is a test"""),
+        (
+            """[TOOL_CALLS] [ {"name":"add" , "arguments" : {"a": 3, "b": 4} } ]""",  # noqa: E501
+            [
+                ToolCall(
+                    function=FunctionCall(
+                        name="add", arguments=json.dumps({"a": 3, "b": 4})
+                    )
+                )
+            ],
+            "",
+        ),
+        (
+            """[TOOL_CALLS] [{"name": "add", "arguments":{"a": "3", "b": "4"}}]""",  # noqa: E501
+            [
+                ToolCall(
+                    function=FunctionCall(
+                        name="add", arguments=json.dumps({"a": "3", "b": "4"})
+                    )
+                )
+            ],
+            "",
+        ),
+        (
+            """[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""",  # noqa: E501
+            [
+                ToolCall(
+                    function=FunctionCall(
+                        name="get_current_weather",
+                        arguments=json.dumps(
+                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
+                        ),
+                    )
+                )
+            ],
+            "",
+        ),
+        (
+            """[TOOL_CALLS] [{"arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""",  # noqa: E501
+            [
+                ToolCall(
+                    function=FunctionCall(
+                        name="get_current_weather",
+                        arguments=json.dumps(
+                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
+                        ),
+                    )
+                )
+            ],
+            "",
+        ),
+        (
+            """[TOOL_CALLS] [{"arguments": {"name": "John Doe"}, "name": "get_age"}]""",  # noqa: E501
+            [
+                ToolCall(
+                    function=FunctionCall(
+                        name="get_age",
+                        arguments=json.dumps(
+                            {
+                                "name": "John Doe",
+                            }
+                        ),
+                    )
+                )
+            ],
+            "",
+        ),
+        (
+            """[TOOL_CALLS] [{"name": "add", "arguments": {"a": 3.5, "b": 4}}, {"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""",  # noqa: E501
+            [
+                ToolCall(
+                    function=FunctionCall(
+                        name="add", arguments=json.dumps({"a": 3.5, "b": 4})
+                    )
+                ),
+                ToolCall(
+                    function=FunctionCall(
+                        name="get_current_weather",
+                        arguments=json.dumps(
+                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
+                        ),
+                    )
+                ),
+            ],
+            "",
+        ),
+    ],
+)
+def test_extract_tool_calls_streaming_pre_v11_tokenizer(
+    mistral_pre_v11_tool_parser,
+    mistral_pre_v11_tokenizer,
+    model_output,
+    expected_tool_calls,
+    expected_content,
+):
+    _test_extract_tool_calls_streaming(
+        mistral_pre_v11_tool_parser,
+        mistral_pre_v11_tokenizer,
+        model_output,
+        None,
+        expected_tool_calls,
+        expected_content,
+    )
+
+
+@pytest.mark.parametrize(
+    ids=[
+        "single_tool_add",
+        "single_tool_add_strings",
+        "multiple_tools",
+    ],
+    argnames=["tools", "expected_tool_calls", "expected_content"],
+    argvalues=[
+        (
+            [("add", '{"a": 3, "b": 4}')],
+            # [TOOL_CALLS]add{"a": 3, "b": 4}
+            [
+                ToolCall(
+                    function=FunctionCall(
+                        name="add", arguments=json.dumps({"a": 3, "b": 4})
+                    )
+                )
+            ],
+            "",
+        ),
+        (
+            [("add_two_strings", '{"a": "3", "b": "4"}')],
+            # [TOOL_CALLS]add_two_strings{"a": "3", "b": "4"}
+            [
+                ToolCall(
+                    function=FunctionCall(
+                        name="add_two_strings",
+                        arguments=json.dumps({"a": "3", "b": "4"}),
+                    )
+                )
+            ],
+            "",
+        ),
+        (
+            [
+                ("add", '{"a": 3.5, "b": 4}'),
+                (
+                    "get_current_weather",
+                    '{"city": "San Francisco", "state": "CA", "unit": "celsius"}',  # noqa: E501
+                ),
+            ],
+            # [TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"}  # noqa: E501
+            [
+                ToolCall(
+                    function=FunctionCall(
+                        name="add", arguments=json.dumps({"a": 3.5, "b": 4})
+                    )
+                ),
+                ToolCall(
+                    function=FunctionCall(
+                        name="get_current_weather",
+                        arguments=json.dumps(
+                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
+                        ),
+                    )
+                ),
+            ],
+            "",
+        ),
+    ],
+)
+def test_extract_tool_calls_streaming(
+    mistral_tool_parser,
+    mistral_tokenizer,
+    tools,
+    expected_tool_calls,
+    expected_content,
+):
+    _test_extract_tool_calls_streaming(
+        mistral_tool_parser,
+        mistral_tokenizer,
+        None,
+        tools,
+        expected_tool_calls,
+        expected_content,
+    )
+
+
+@pytest.mark.parametrize(
+    ids=[
+        "single_tool_add",
+        "single_tool_weather",
+        "multiple_tool_calls",
+        "content_before_tool",
+    ],
+    argnames=["model_output", "expected_tool_calls", "expected_content"],
+    argvalues=[
+        (
+            """[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}""",  # noqa: E501
+            [
+                ToolCall(
+                    function=FunctionCall(
+                        name="add_this_and_that",
+                        arguments=json.dumps({"a": 3.5, "b": 4}),
+                    )
+                )
+            ],
+            "",
+        ),
+        (
+            """[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"}""",  # noqa: E501
+            [
+                ToolCall(
+                    function=FunctionCall(
+                        name="get_current_weather",
+                        arguments=json.dumps(
+                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
+                        ),
+                    )
+                )
+            ],
+            "",
+        ),
+        (
+            """[TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]multiply{"a": 3, "b": 6}""",  # noqa: E501
+            [
+                ToolCall(
+                    function=FunctionCall(
+                        name="add", arguments=json.dumps({"a": 3.5, "b": 4})
+                    )
+                ),
+                ToolCall(
+                    function=FunctionCall(
+                        name="multiply", arguments=json.dumps({"a": 3, "b": 6})
+                    )
+                ),
+            ],
+            "",
+        ),
+        (
+            # Content may appear before the tool calls,
+            # but nothing should come after them
+            """bla[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}""",  # noqa: E501
+            [
+                ToolCall(
+                    function=FunctionCall(
+                        name="add_this_and_that",
+                        arguments=json.dumps({"a": 3.5, "b": 4}),
+                    )
+                )
+            ],
+            "bla",
+        ),
+    ],
+)
+def test_extract_tool_calls_streaming_one_chunk(
+    mistral_tool_parser,
+    mistral_tokenizer,
+    model_output,
+    expected_tool_calls,
+    expected_content,
+):
+    if isinstance(mistral_tokenizer, MistralTokenizer):
+        all_token_ids = mistral_tokenizer.encode(model_output)
+    else:
+        all_token_ids = mistral_tokenizer.encode(model_output, add_special_tokens=False)
+    all_token_ids = fix_tool_call_tokenization(
+        all_token_ids, mistral_tool_parser, mistral_tokenizer
+    )
+
+    delta_message = mistral_tool_parser.extract_tool_calls_streaming(
+        previous_text="",
+        current_text=model_output,
+        delta_text=model_output,
+        previous_token_ids=[],
+        current_token_ids=all_token_ids,
+        delta_token_ids=all_token_ids,
+        request=None,
+    )  # type: ignore[arg-type]
+    assert isinstance(delta_message, DeltaMessage)
+    assert len(delta_message.tool_calls) == len(expected_tool_calls)
+
+    assert_tool_calls(delta_message.tool_calls, expected_tool_calls)
+
+    if delta_message.content is None:
+        assert expected_content == ""
+    else:
+        assert delta_message.content == expected_content
+
+
+@pytest.mark.parametrize(
+    ids=[
+        "no_tools",
+        "single_tool_add",
+        "single_tool_add_strings",
+        "single_tool_weather",
+        "argument_before_name",
+        "argument_before_name_and_name_in_argument",
+        "multiple_tools",
+    ],
+    argnames=["model_output", "expected_tool_calls", "expected_content"],
+    argvalues=[
+        ("""This is a test""", [], """This is a test"""),
+        (
+            """[TOOL_CALLS] [ {"name":"add" , "arguments" : {"a": 3, "b": 4} } ]""",  # noqa: E501
+            [
+                ToolCall(
+                    function=FunctionCall(
+                        name="add", arguments=json.dumps({"a": 3, "b": 4})
+                    )
+                )
+            ],
+            "",
+        ),
+        (
+            """[TOOL_CALLS] [{"name": "add", "arguments":{"a": "3", "b": "4"}}]""",  # noqa: E501
+            [
+                ToolCall(
+                    function=FunctionCall(
+                        name="add", arguments=json.dumps({"a": "3", "b": "4"})
+                    )
+                )
+            ],
+            "",
+        ),
+        (
+            """[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""",  # noqa: E501
+            [
+                ToolCall(
+                    function=FunctionCall(
+                        name="get_current_weather",
+                        arguments=json.dumps(
+                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
+                        ),
+                    )
+                )
+            ],
+            "",
+        ),
+        (
+            """[TOOL_CALLS] [{"arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""",  # noqa: E501
+            [
+                ToolCall(
+                    function=FunctionCall(
+                        name="get_current_weather",
+                        arguments=json.dumps(
+                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
+                        ),
+                    )
+                )
+            ],
+            "",
+        ),
+        (
+            """[TOOL_CALLS] [{"arguments": {"name": "John Doe"}, "name": "get_age"}]""",  # noqa: E501
+            [
+                ToolCall(
+                    function=FunctionCall(
+                        name="get_age",
+                        arguments=json.dumps(
+                            {
+                                "name": "John Doe",
+                            }
+                        ),
+                    )
+                )
+            ],
+            "",
+        ),
+        (
+            """[TOOL_CALLS] [{"arguments": {"a": 3.5, "b": 4}, "name": "add"}, {"arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""",  # noqa: E501
+            [
+                ToolCall(
+                    function=FunctionCall(
+                        name="add", arguments=json.dumps({"a": 3.5, "b": 4})
+                    )
+                ),
+                ToolCall(
+                    function=FunctionCall(
+                        name="get_current_weather",
+                        arguments=json.dumps(
+                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
+                        ),
+                    )
+                ),
+            ],
+            "",
+        ),
+    ],
+)
+def test_extract_tool_calls_streaming_pre_v11_tokenizer_one_chunk(
+    mistral_pre_v11_tool_parser,
+    mistral_pre_v11_tokenizer,
+    model_output,
+    expected_tool_calls,
+    expected_content,
+):
+    if isinstance(mistral_pre_v11_tokenizer, MistralTokenizer):
+        all_token_ids = mistral_pre_v11_tokenizer.encode(model_output)
+    else:
+        all_token_ids = mistral_pre_v11_tokenizer.encode(
+            model_output, add_special_tokens=False
+        )
+    all_token_ids = fix_tool_call_tokenization(
+        all_token_ids, mistral_pre_v11_tool_parser, mistral_pre_v11_tokenizer
+    )
+
+    delta_message = mistral_pre_v11_tool_parser.extract_tool_calls_streaming(
+        previous_text="",
+        current_text=model_output,
+        delta_text=model_output,
+        previous_token_ids=[],
+        current_token_ids=all_token_ids,
+        delta_token_ids=all_token_ids,
+        request=None,
+    )  # type: ignore[arg-type]
+    assert isinstance(delta_message, DeltaMessage)
+    assert len(delta_message.tool_calls) == len(expected_tool_calls)
+
+    assert_tool_calls(delta_message.tool_calls, expected_tool_calls)
+
+    if delta_message.content is None:
+        assert expected_content == ""
+    else:
+        assert delta_message.content == expected_content
diff --git a/tests/tool_use/utils.py b/tests/tool_use/utils.py
index 7584b903156b7..de7284a309c53 100644
--- a/tests/tool_use/utils.py
+++ b/tests/tool_use/utils.py
@@ -123,7 +123,7 @@ CONFIGS: dict[str, ServerConfig] = {
         "supports_parallel": True,
         "extended": True,
     },
-    "mistral": {
+    "mistral-7b": {
        "model": "mistralai/Mistral-7B-Instruct-v0.3",
        "arguments": [
            "--enforce-eager",
@@ -145,6 +145,32 @@ CONFIGS: dict[str, ServerConfig] = {
         "call the tool. Otherwise, answer the user's query directly "
         "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
         "to the user's question - just respond to it normally.",
+        "supports_parallel": True,
+    },
+    "mistral-small-3.2": {
+        "model": "mistralai/Mistral-Small-3.2-24B-Instruct-2506",
+        "arguments": [
+            "--enforce-eager",
+            "--no-enable-prefix-caching",
+            "--tool-call-parser",
+            "mistral",
+            "--tokenizer-mode",
+            "mistral",
+            "--config-format",
+            "mistral",
+            "--load-format",
+            "mistral",
+            "--tensor-parallel-size",
+            "4",
+            '--ignore-patterns="consolidated.safetensors"',
+        ],
+        "system_prompt": "You are a helpful assistant with access to tools. If a tool"
+        " that you have would be helpful to answer a user query, "
+        "call the tool. Otherwise, answer the user's query directly "
+        "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
+        "to the user's question - just respond to it normally.",
+        "supports_parallel": True,
+        "extended": True,
+    },
     # FIXME: This test currently fails, need to debug why.
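Aside (editor's note): the two output formats exercised by these tests differ by tokenizer version. Pre-v11 models emit a JSON array after [TOOL_CALLS]; v11 models emit bare `name{json-args}` pairs. A minimal standalone sketch of how the v11 form decomposes, using the same pattern the parser compiles as fn_name_regex (stdlib re stands in here for the `regex` package vLLM uses; the model_output value is illustrative):

import json
import re

# Same pattern the parser compiles as fn_name_regex: a function name
# followed by a brace-delimited JSON arguments object.
fn_name_regex = re.compile(r"([a-zA-Z0-9_-]+)(\{[\s\S]*?\}+)", re.DOTALL)

model_output = '[TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]multiply{"a": 3, "b": 6}'
# Split on the bot token, then pull (name, arguments) out of each segment,
# mirroring the loop in MistralToolParser.extract_tool_calls.
for segment in model_output.split("[TOOL_CALLS]"):
    for name, args in fn_name_regex.findall(segment):
        print(name, json.loads(args))
# -> add {'a': 3.5, 'b': 4}
#    multiply {'a': 3, 'b': 6}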
# "granite20b": { diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index b89db60545abd..aa5089ffe84d7 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -3,12 +3,12 @@ import json from collections.abc import Sequence +from enum import Enum, auto from random import choices from string import ascii_letters, digits -import partial_json_parser +import ijson import regex as re -from partial_json_parser.core.options import Allow from pydantic import Field from vllm.entrypoints.openai.protocol import ( @@ -23,7 +23,6 @@ from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser, ) -from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff from vllm.logger import init_logger from vllm.tokenizers import MistralTokenizer, TokenizerLike @@ -32,6 +31,22 @@ logger = init_logger(__name__) ALPHANUMERIC = ascii_letters + digits +class StreamingState(Enum): + """Enum for tracking the current streaming parsing state.""" + + WAITING_FOR_TOOL_START = auto() + WAITING_FOR_TOOL_KEY = ( + auto() + ) # waiting for the "name" or "arguments" key to be complete + PARSING_NAME = auto() + PARSING_NAME_COMPLETED = auto() + WAITING_FOR_ARGUMENTS_START = auto() + PARSING_ARGUMENTS = auto() + PARSING_ARGUMENTS_COMPLETED = auto() + TOOL_COMPLETE = auto() + ALL_TOOLS_COMPLETE = auto() + + class MistralToolCall(ToolCall): id: str = Field(default_factory=lambda: MistralToolCall.generate_random_id()) @@ -46,8 +61,8 @@ class MistralToolCall(ToolCall): return id.isalnum() and len(id) == 9 -def _is_fn_name_regex_support(model_tokenizer: TokenizerLike) -> bool: - return ( +def _is_pre_v11_tokeniser(model_tokenizer: TokenizerLike) -> bool: + return not ( isinstance(model_tokenizer, MistralTokenizer) and model_tokenizer.version >= 11 ) @@ -69,16 +84,22 @@ class MistralToolParser(ToolParser): # initialize properties used for state when parsing tool calls in # streaming mode - self.prev_tool_call_arr: list[dict] = [] self.current_tool_id: int = -1 - self.current_tool_name_sent: bool = False - self.streamed_args_for_tool: list[ - str - ] = [] # map what has been streamed for each tool so far to a list + self.streaming_state: StreamingState = StreamingState.WAITING_FOR_TOOL_START + + # For streaming pre v11 tokenizer tool calls + self.current_tool_name: str | None = None + self.current_tool_mistral_id: str | None = None + self.starting_new_tool = False + if _is_pre_v11_tokeniser(self.model_tokenizer): + self.parse_coro = ijson.parse_coro( + self.update_stream_state_pre_v11_tokenizer() + ) + self.bot_token = "[TOOL_CALLS]" self.bot_token_id = self.vocab.get(self.bot_token) self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL) - if _is_fn_name_regex_support(self.model_tokenizer): + if not _is_pre_v11_tokeniser(self.model_tokenizer): self.fn_name_regex = re.compile( r"([a-zA-Z0-9_-]+)(\{[\s\S]*?\}+)", re.DOTALL ) @@ -131,18 +152,19 @@ class MistralToolParser(ToolParser): # jsons is difficult try: if self.fn_name_regex: - matches = self.fn_name_regex.findall(tool_content) - function_call_arr = [] - for match in matches: - fn_name = match[0] - args = match[1] + for single_tool_content in model_output.split(self.bot_token): + matches = self.fn_name_regex.findall(single_tool_content) - # fn_name is encoded outside serialized json dump - # only arguments are serialized - 
function_call_arr.append( - {"name": fn_name, "arguments": json.loads(args)} - ) + for match in matches: + fn_name = match[0] + args = match[1] + + # fn_name is encoded outside serialized json dump + # only arguments are serialized + function_call_arr.append( + {"name": fn_name, "arguments": json.loads(args)} + ) else: function_call_arr = json.loads(tool_content) except json.JSONDecodeError: @@ -193,198 +215,372 @@ class MistralToolParser(ToolParser): delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> DeltaMessage | None: - # if the tool call token is not in the tokens generated so far, append - # output to contents since it's not a tool - if self.bot_token not in current_text: + if self.bot_token_id not in current_token_ids: + # if the tool call token is not in the tokens generated so far, + # append output to contents since it's not a tool return DeltaMessage(content=delta_text) - # if the tool call token ID IS in the tokens generated so far, that + # if the tool call token IS in the tokens generated so far, that # means we're parsing as tool calls now - - # handle if we detected the BOT token which means the start of tool - # calling - if self.bot_token_id in delta_token_ids and len(delta_token_ids) == 1: - # if it's the only token, return None, so we don't send a chat - # completion any don't send a control token - return None - - # bit mask flags for partial JSON parsing. If the name hasn't been - # sent yet, don't allow sending - # an incomplete string since OpenAI only ever (as far as I have - # seen) allows sending the entire tool/ function name at once. - flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR try: - # replace BOT token with empty string, and convert single quotes - # to double to allow parsing as JSON since mistral uses single - # quotes instead of double for tool calls - parsable_arr = current_text.split(self.bot_token)[-1] - - # tool calls are generated in an array, so do partial JSON - # parsing on the entire array - try: - tool_call_arr: list[dict] = partial_json_parser.loads( - parsable_arr, flags + if _is_pre_v11_tokeniser(self.model_tokenizer): + return self._extract_tool_calls_streaming_pre_v11_tokenizer( + delta_text=delta_text, + delta_token_ids=delta_token_ids, ) - except partial_json_parser.core.exceptions.MalformedJSON: - logger.debug("not enough tokens to parse into JSON yet") - return None - - # select as the current tool call the one we're on the state at - - current_tool_call: dict = ( - tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {} - ) - - # case -- if no tokens have been streamed for the tool, e.g. - # only the array brackets, stream nothing - if len(tool_call_arr) == 0: - return None - - # case: we are starting a new tool in the array - # -> array has > 0 length AND length has moved past cursor - elif ( - len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1 - ): - # if we're moving on to a new call, first make sure we - # haven't missed anything in the previous one that was - # auto-generated due to JSON completions, but wasn't - # streamed to the client yet. 
-                if self.current_tool_id >= 0:
-                    diff: str | None = current_tool_call.get("arguments")
-
-                    if diff:
-                        diff = json.dumps(diff, ensure_ascii=False).replace(
-                            self.streamed_args_for_tool[self.current_tool_id], ""
-                        )
-                        delta = DeltaMessage(
-                            tool_calls=[
-                                DeltaToolCall(
-                                    index=self.current_tool_id,
-                                    function=DeltaFunctionCall(
-                                        arguments=diff
-                                    ).model_dump(exclude_none=True),
-                                )
-                            ]
-                        )
-                        self.streamed_args_for_tool[self.current_tool_id] += diff
-                    else:
-                        delta = None
-                else:
-                    delta = None
-                # re-set stuff pertaining to progress in the current tool
-                self.current_tool_id = len(tool_call_arr) - 1
-                self.current_tool_name_sent = False
-                self.streamed_args_for_tool.append("")
-                logger.debug("starting on new tool %d", self.current_tool_id)
-                return delta
-
-            # case: update an existing tool - this is handled below
-
-            # if the current tool name hasn't been sent, send if available
-            # - otherwise send nothing
-            if not self.current_tool_name_sent:
-                function_name = current_tool_call.get("name")
-                if function_name:
-                    delta = DeltaMessage(
-                        tool_calls=[
-                            DeltaToolCall(
-                                index=self.current_tool_id,
-                                type="function",
-                                id=MistralToolCall.generate_random_id(),
-                                function=DeltaFunctionCall(
-                                    name=function_name
-                                ).model_dump(exclude_none=True),
-                            )
-                        ]
-                    )
-                    self.current_tool_name_sent = True
-                else:
-                    delta = None
-
-            # now we know we're on the same tool call and we're streaming
-            # arguments
             else:
-                prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
-                    "arguments"
+                return self._extract_tool_calls_streaming(
+                    delta_text=delta_text, delta_token_ids=delta_token_ids
                 )
-                cur_arguments = current_tool_call.get("arguments")
-
-                new_text = delta_text.replace("'", '"')
-                if '"}' in new_text:
-                    new_text = new_text[: new_text.rindex('"}')]
-
-                if not cur_arguments and not prev_arguments:
-                    delta = None
-                elif not cur_arguments and prev_arguments:
-                    logger.error(
-                        "INVARIANT - impossible to have arguments reset mid-arguments"
-                    )
-                    delta = None
-                elif cur_arguments and not prev_arguments:
-                    cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False)[
-                        :-2
-                    ]
-                    logger.debug("finding %s in %s", new_text, cur_arguments_json)
-
-                    if new_text not in cur_arguments_json:
-                        return None
-                    arguments_delta = cur_arguments_json[
-                        : cur_arguments_json.rindex(new_text) + len(new_text)
-                    ]
-                    logger.debug(
-                        "First tokens in arguments received: %s", arguments_delta
-                    )
-                    delta = DeltaMessage(
-                        tool_calls=[
-                            DeltaToolCall(
-                                index=self.current_tool_id,
-                                function=DeltaFunctionCall(
-                                    arguments=arguments_delta
-                                ).model_dump(exclude_none=True),
-                            )
-                        ]
-                    )
-                    self.streamed_args_for_tool[self.current_tool_id] += arguments_delta
-
-                elif cur_arguments and prev_arguments:
-                    cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)
-                    prev_args_json = json.dumps(prev_arguments, ensure_ascii=False)
-                    logger.debug(
-                        "Searching for diff between \n%s\n%s",
-                        cur_args_json,
-                        prev_args_json,
-                    )
-
-                    argument_diff = extract_intermediate_diff(
-                        cur_args_json, prev_args_json
-                    )
-                    logger.debug("got arguments diff: %s", argument_diff)
-                    delta = DeltaMessage(
-                        tool_calls=[
-                            DeltaToolCall(
-                                index=self.current_tool_id,
-                                function=DeltaFunctionCall(
-                                    arguments=argument_diff
-                                ).model_dump(exclude_none=True),
-                            )
-                        ]
-                    )
-                    self.streamed_args_for_tool[self.current_tool_id] += argument_diff
-                else:
-                    # try parsing it with regular JSON - if it works we're
-                    # at the end, and we need to send the difference between
-                    # tokens streamed so far and the valid JSON
-                    delta = None
-
-            # check to see if the name is defined and has been sent. if so,
-            # stream the name - otherwise keep waiting
-            # finish by setting old and returning None as base case
-            self.prev_tool_call_arr = tool_call_arr
-            return delta
-
         except Exception:
             logger.exception("Error trying to handle streaming tool call.")
-            logger.debug(
-                "Skipping chunk as a result of tool streaming extraction error"
-            )
             return None
+
+    def _extract_tool_calls_streaming(
+        self,
+        delta_text: str,
+        delta_token_ids: Sequence[int],
+    ) -> DeltaMessage | None:
+        """
+        Extracts tool calls for Mistral models
+        doing tool calls of the following format:
+        `[TOOL_CALLS]add{"a": 3.5, "b": 4}`
+        """
+        additional_content: str = ""
+        if self.streaming_state == StreamingState.WAITING_FOR_TOOL_START:
+            # this is the first tool call
+            assert self.bot_token_id in delta_token_ids
+            if not delta_text.startswith(self.bot_token):
+                additional_content += delta_text.split(self.bot_token)[0]
+                delta_text = self.bot_token + "".join(
+                    delta_text.split(self.bot_token)[1:]
+                )
+
+        delta_tool_calls = self._generate_delta_tool_call(delta_text)
+        if not additional_content and len(delta_tool_calls) == 0:
+            if self.streaming_state in [
+                StreamingState.PARSING_ARGUMENTS,
+                StreamingState.PARSING_ARGUMENTS_COMPLETED,
+                StreamingState.TOOL_COMPLETE,
+                StreamingState.ALL_TOOLS_COMPLETE,
+            ]:
+                # Return an empty DeltaMessage once the tool calls are all done
+                # so that finish_reason gets set.
+                return DeltaMessage()
+            else:
+                # Return None while the tool call is unlikely to be finished.
+                # This can occur while the name is being parsed, for example:
+                # we wait for the name to be complete
+                # before sending the function name.
+                return None
+
+        delta = DeltaMessage()
+        if additional_content:
+            delta.content = additional_content
+        if len(delta_tool_calls) > 0:
+            delta.tool_calls = delta_tool_calls
+
+        # HACK: serving_chat.py inspects the internal state of tool parsers
+        # when determining its final streaming delta, automatically
+        # adding autocompleted JSON.
+        # These two lines avoid that nonsense while ensuring finish_reason
+        # is set to tool_calls when at least one tool is called.
+        if delta_tool_calls and not self.prev_tool_call_arr:
+            self.prev_tool_call_arr = [{"arguments": {}}]
+        return delta
+
+    def _generate_delta_tool_call(self, delta_text: str) -> list[DeltaToolCall]:
+        if delta_text == "" or delta_text is None:
+            return []
+        delta_function_name = None
+        tool_id = None
+        if self.streaming_state not in [
+            StreamingState.PARSING_NAME,
+            StreamingState.PARSING_ARGUMENTS,
+        ] and delta_text.startswith(self.bot_token):
+            self.current_tool_id += 1
+            self.streaming_state = StreamingState.PARSING_NAME
+            delta_text = delta_text.replace(self.bot_token, "", 1)
+        if self.streaming_state == StreamingState.PARSING_NAME:
+            if self.current_tool_name is None:
+                self.current_tool_name = ""
+            # The name stops where the arguments start
+            # And the arguments start with the `{` char
+            if "{" in delta_text:
+                tool_id = MistralToolCall.generate_random_id()
+                delta_function_name = delta_text.split("{")[0]
+                self.current_tool_name += delta_function_name
+                delta_text = delta_text[len(delta_function_name) :]
+                self.streaming_state = StreamingState.PARSING_ARGUMENTS
+            else:
+                # we want to send the tool name once it's complete
+                self.current_tool_name += delta_text
+                return []
+        if self.streaming_state == StreamingState.PARSING_ARGUMENTS:
+            next_function_text = None
+            if self.bot_token in delta_text:
+                # current tool call is over
+                delta_arguments = ""
+                delta_arguments += delta_text.split(self.bot_token)[0]
+                next_function_text = delta_text[len(delta_arguments) :]
+                self.streaming_state = StreamingState.TOOL_COMPLETE
+            else:
+                delta_arguments = delta_text
+            ret = []
+            if self.current_tool_name or delta_arguments:
+                ret += [
+                    DeltaToolCall(
+                        index=self.current_tool_id,
+                        type="function",
+                        id=tool_id,
+                        function=DeltaFunctionCall(
+                            name=self.current_tool_name, arguments=delta_arguments
+                        ).model_dump(exclude_none=True),
+                    )
+                ]
+                self.current_tool_name = None
+            if next_function_text:
+                ret += self._generate_delta_tool_call(next_function_text)
+            return ret
+        # Should not happen
+        return []
+
+    @ijson.coroutine
+    def update_stream_state_pre_v11_tokenizer(self):
+        while True:
+            (prefix, event, value) = yield
+
+            if prefix == "item" and event == "start_map":
+                self.streaming_state = StreamingState.WAITING_FOR_TOOL_KEY
+            if prefix == "item" and event == "map_key" and value == "name":
+                self.streaming_state = StreamingState.PARSING_NAME
+            if prefix == "item.name" and event == "string":
+                self.current_tool_name = value
+                self.streaming_state = StreamingState.PARSING_NAME_COMPLETED
+            if prefix == "item" and event == "map_key" and value == "arguments":
+                self.streaming_state = StreamingState.WAITING_FOR_ARGUMENTS_START
+            if prefix == "item.arguments" and event == "start_map":
+                self.streaming_state = StreamingState.PARSING_ARGUMENTS
+            if prefix == "item.arguments" and event == "end_map":
+                self.streaming_state = StreamingState.PARSING_ARGUMENTS_COMPLETED
+            if prefix == "item" and event == "end_map":
+                self.streaming_state = StreamingState.TOOL_COMPLETE
+            if prefix == "" and event == "end_array":
+                self.streaming_state = StreamingState.ALL_TOOLS_COMPLETE
+
+    def _extract_tool_calls_streaming_pre_v11_tokenizer(
+        self,
+        delta_text: str,
+        delta_token_ids: Sequence[int],
+    ) -> DeltaMessage | None:
+        """
+        Extracts tool calls for Mistral models
+        doing tool calls of the following format:
+        `[TOOL_CALLS][{"name": "add", "arguments":{"a": 3.5, "b": 4}}]`
+        """
+        assert self.parse_coro is not None
+        content = None
+        delta_tool_calls: list[DeltaToolCall] = []
+        current_tool_call: DeltaToolCall = DeltaToolCall(
+            index=self.current_tool_id, type="function"
+        )
+        current_tool_call_modified = False
+        if self.bot_token_id in delta_token_ids:
+            # this is the first tool call
+            if not delta_text.startswith(self.bot_token):
+                content = delta_text.split(self.bot_token)[0]
+            delta_text = "".join(delta_text.split(self.bot_token)[1:])
+
+        # Split the delta text carefully to catch the ijson events,
+        # as ijson does not give us the index in the text at each event.
+        # We need to cut so that we know
+        # where in the text the events are emitted from.
+        while len(delta_text) > 0:
+            streaming_state_before_parse = self.streaming_state
+
+            if self.streaming_state == StreamingState.WAITING_FOR_TOOL_START:
+                delta_to_be_parsed, delta_text = self._split_delta(
+                    delta_text=delta_text,
+                    stop_after_opening_curly_braces=1,
+                )
+            elif self.streaming_state == StreamingState.WAITING_FOR_TOOL_KEY:
+                # Wait until another key is sent
+                # or the current tool is completed
+                delta_to_be_parsed, delta_text = self._split_delta(
+                    delta_text=delta_text,
+                    stop_after_colon=1,
+                    stop_after_opening_curly_braces=1,
+                    # if the tool ends, we want to separate
+                    # at the start of the next tool
+                )
+            elif self.streaming_state == StreamingState.PARSING_NAME:
+                delta_to_be_parsed, delta_text = self._split_delta(
+                    delta_text=delta_text,
+                    stop_after_comma=1,
+                    stop_after_closing_brackets=1,
+                )
+            elif self.streaming_state == StreamingState.WAITING_FOR_ARGUMENTS_START:
+                delta_to_be_parsed, delta_text = self._split_delta(
+                    delta_text=delta_text,
+                    stop_after_opening_curly_braces=1,
+                )
+            elif self.streaming_state == StreamingState.PARSING_ARGUMENTS:
+                delta_to_be_parsed, delta_text = self._split_delta(
+                    delta_text=delta_text,
+                    stop_after_closing_curly_braces=1,
+                    # we could be more clever
+                    # by listening to item.arguments.* start_map events
+                    # and know how many curly braces we can allow
+                )
+            elif self.streaming_state in [
+                StreamingState.PARSING_ARGUMENTS_COMPLETED,
+                StreamingState.PARSING_NAME_COMPLETED,
+            ]:
+                delta_to_be_parsed, delta_text = self._split_delta(
+                    delta_text=delta_text,
+                    stop_after_closing_curly_braces=1,
+                    stop_after_closing_brackets=1,
+                )
+            elif self.streaming_state == StreamingState.TOOL_COMPLETE:
+                delta_to_be_parsed, delta_text = self._split_delta(
+                    delta_text=delta_text,
+                    stop_after_opening_curly_braces=1,
+                    stop_after_closing_brackets=1,
+                )
+            elif self.streaming_state == StreamingState.ALL_TOOLS_COMPLETE:
+                content = delta_text
+                delta_text = ""
+            else:
+                delta_to_be_parsed = delta_text
+                delta_text = ""
+
+            if self.streaming_state != StreamingState.ALL_TOOLS_COMPLETE:
+                self.parse_coro.send(delta_to_be_parsed.encode("utf-8"))
+
+            # Given the parsed text and the possible streaming state change,
+            # let's add to the tool delta
+            if (
+                (streaming_state_before_parse != self.streaming_state)
+                and streaming_state_before_parse
+                in [StreamingState.WAITING_FOR_TOOL_START, StreamingState.TOOL_COMPLETE]
+                and self.streaming_state
+                not in [
+                    StreamingState.ALL_TOOLS_COMPLETE,
+                    StreamingState.TOOL_COMPLETE,
+                    StreamingState.WAITING_FOR_TOOL_START,
+                ]
+            ):
+                # starting a new tool call
+                if current_tool_call_modified:
+                    if self.current_tool_mistral_id is not None:
+                        current_tool_call.id = self.current_tool_mistral_id
+                        self.current_tool_mistral_id = None
+                    delta_tool_calls.append(current_tool_call)
+                    current_tool_call_modified = False
+                self.current_tool_id += 1
+                self.current_tool_mistral_id = MistralToolCall.generate_random_id()
+                current_tool_call = DeltaToolCall(
+                    index=self.current_tool_id,
+                    type="function",
+                )
+            if current_tool_call.function is None:
+                current_tool_call.function = DeltaFunctionCall()
+
+            if self.current_tool_name is not None:
+                # we have the complete tool name
+                current_tool_call_modified = True
+                current_tool_call.function.name = self.current_tool_name
+                self.current_tool_name = None
+                if self.streaming_state == StreamingState.PARSING_NAME_COMPLETED:
+                    self.streaming_state = StreamingState.WAITING_FOR_TOOL_KEY
+            if self.streaming_state in [
+                StreamingState.PARSING_ARGUMENTS,
+                StreamingState.PARSING_ARGUMENTS_COMPLETED,
+            ]:
+                if self.streaming_state == StreamingState.PARSING_ARGUMENTS_COMPLETED:
+                    self.streaming_state = StreamingState.WAITING_FOR_TOOL_KEY
+                # the delta_to_be_parsed is part of arguments.
+                current_tool_call_modified = True
+                if current_tool_call.function.arguments is None:
+                    current_tool_call.function.arguments = delta_to_be_parsed
+                else:
+                    current_tool_call.function.arguments += delta_to_be_parsed
+                if streaming_state_before_parse != StreamingState.PARSING_ARGUMENTS:
+                    # It's the first chunk of arg. let's lstrip it
+                    current_tool_call.function.arguments = (
+                        current_tool_call.function.arguments.lstrip()
+                    )
+
+        if current_tool_call_modified:
+            if self.current_tool_mistral_id is not None:
+                current_tool_call.id = self.current_tool_mistral_id
+                self.current_tool_mistral_id = None
+            delta_tool_calls.append(current_tool_call)
+
+        # HACK: serving_chat.py inspects the internal state of tool parsers
+        # when determining its final streaming delta, automatically
+        # adding autocompleted JSON.
+        # These two lines avoid that nonsense while ensuring finish_reason
+        # is set to tool_calls when at least one tool is called.
+        if delta_tool_calls and not self.prev_tool_call_arr:
+            self.prev_tool_call_arr = [{"arguments": {}}]
+
+        if content or len(delta_tool_calls) > 0:
+            delta_message = DeltaMessage()
+            if content:
+                delta_message.content = content
+            if len(delta_tool_calls) > 0:
+                delta_message.tool_calls = delta_tool_calls
+            return delta_message
+        else:
+            if self.streaming_state == StreamingState.ALL_TOOLS_COMPLETE:
+                return DeltaMessage()
+            else:
+                return None
+
+    def _split_delta(
+        self,
+        delta_text: str,
+        stop_after_quotes: int = -1,
+        stop_after_opening_curly_braces: int = -1,
+        stop_after_closing_curly_braces: int = -1,
+        stop_after_closing_brackets: int = -1,
+        stop_after_colon: int = -1,
+        stop_after_comma: int = -1,
+    ) -> tuple[str, str]:
+        delta_to_be_parsed = ""
+        for i, c in enumerate(delta_text):
+            if c in ['"', "'"]:
+                delta_to_be_parsed += c
+                stop_after_quotes -= 1
+                if stop_after_quotes == 0:
+                    return (delta_to_be_parsed, delta_text[i + 1 :])
+            elif c == "{":
+                delta_to_be_parsed += c
+                stop_after_opening_curly_braces -= 1
+                if stop_after_opening_curly_braces == 0:
+                    return (delta_to_be_parsed, delta_text[i + 1 :])
+            elif c == "}":
+                delta_to_be_parsed += c
+                stop_after_closing_curly_braces -= 1
+                if stop_after_closing_curly_braces == 0:
+                    return (delta_to_be_parsed, delta_text[i + 1 :])
+            elif c == "]":
+                delta_to_be_parsed += c
+                stop_after_closing_brackets -= 1
+                if stop_after_closing_brackets == 0:
+                    return (delta_to_be_parsed, delta_text[i + 1 :])
+            elif c == ":":
+                delta_to_be_parsed += c
+                stop_after_colon -= 1
+                if stop_after_colon == 0:
+                    return (delta_to_be_parsed, delta_text[i + 1 :])
+            elif c == ",":
+                delta_to_be_parsed += c
+                stop_after_comma -= 1
+                if stop_after_comma == 0:
+                    return (delta_to_be_parsed, delta_text[i + 1 :])
+            else:
+                delta_to_be_parsed += c
+
+        return (delta_to_be_parsed, "")