diff --git a/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py index 92ba1376e2002..f5f327ea068c6 100644 --- a/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest @@ -191,3 +191,27 @@ def test_streaming_tool_call_with_large_steps(): assert reconstructor.tool_calls[0].function == SIMPLE_FUNCTION_CALL assert reconstructor.tool_calls[1].function == PARAMETERLESS_FUNCTION_CALL assert reconstructor.tool_calls[2].function == EMPTY_LIST_FUNCTION_CALL + + +@pytest.mark.parametrize("streaming", [False]) +def test_regex_timeout_handling(streaming: bool): + """test regex timeout is handled gracefully""" + mock_tokenizer = MagicMock() + tool_parser: ToolParser = ToolParserManager.get_tool_parser( + "llama4_pythonic")(mock_tokenizer) + + fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2 + + # create a mock regex that raises TimeoutError + mock_regex = MagicMock() + mock_regex.match.side_effect = TimeoutError("Regex timeout") + + with patch.object(tool_parser, 'TOOL_CALL_REGEX', mock_regex): + content, tool_calls = run_tool_extraction(tool_parser, + fake_problematic_input, + streaming=streaming) + + # should treat as regular text when regex times out + assert content == fake_problematic_input + assert len(tool_calls) == 0 + mock_regex.match.assert_called_once() diff --git a/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py index fbbbc1fb2a596..71f41ea7d93b4 100644 --- a/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest @@ -159,3 +159,27 @@ def test_streaming_tool_call_with_large_steps(): assert reconstructor.tool_calls[0].function == SIMPLE_FUNCTION_CALL assert reconstructor.tool_calls[1].function == PARAMETERLESS_FUNCTION_CALL assert reconstructor.tool_calls[2].function == EMPTY_LIST_FUNCTION_CALL + + +@pytest.mark.parametrize("streaming", [False]) +def test_regex_timeout_handling(streaming: bool): + """test regex timeout is handled gracefully""" + mock_tokenizer = MagicMock() + tool_parser: ToolParser = ToolParserManager.get_tool_parser( + "llama4_pythonic")(mock_tokenizer) + + fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2 + + # create a mock regex that raises TimeoutError + mock_regex = MagicMock() + mock_regex.match.side_effect = TimeoutError("Regex timeout") + + with patch.object(tool_parser, 'TOOL_CALL_REGEX', mock_regex): + content, tool_calls = run_tool_extraction(tool_parser, + fake_problematic_input, + streaming=streaming) + + # should treat as regular text when regex times out + assert content == fake_problematic_input + assert len(tool_calls) == 0 + mock_regex.match.assert_called_once() diff --git a/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py index 858c8db99fd29..323fb144181ea 100644 --- a/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py @@ -7,6 +7,7 @@ from typing import Any, Union import regex as re from transformers import PreTrainedTokenizerBase +import vllm.envs as envs from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -64,7 +65,19 @@ class Llama4PythonicToolParser(ToolParser): if model_output.startswith("<|python_start|>"): model_output = model_output[len("<|python_start|>"):] model_output = model_output.replace("<|python_end|>", "") - if not (self.TOOL_CALL_REGEX.match(model_output)): + + is_tool_call_pattern = False + try: + is_tool_call_pattern = self.TOOL_CALL_REGEX.match( + model_output, + timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS) is not None + except TimeoutError: + logger.warning( + "Regex timeout occurred when matching tool call pattern.") + logger.debug("Regex timeout occurred when matching user input: %s", + model_output) + + if not is_tool_call_pattern: return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=model_output) diff --git a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py index 548ff39d1ca4f..bc5d15dcb82f4 100644 --- a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py @@ -8,6 +8,7 @@ from typing import Any, Union import regex as re from transformers import PreTrainedTokenizerBase +import vllm.envs as envs from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -61,8 +62,18 @@ class PythonicToolParser(ToolParser): """ Extract the tool calls from a complete model response. """ + is_tool_call_pattern = False + try: + is_tool_call_pattern = self.TOOL_CALL_REGEX.match( + model_output, + timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS) is not None + except TimeoutError: + logger.warning( + "Regex timeout occurred when matching tool call pattern.") + logger.debug("Regex timeout occurred when matching user input: %s", + model_output) - if not (self.TOOL_CALL_REGEX.match(model_output)): + if not is_tool_call_pattern: return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=model_output) diff --git a/vllm/envs.py b/vllm/envs.py index dc52bbd8edbc5..44baf5a189b43 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -119,6 +119,7 @@ if TYPE_CHECKING: VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557 VLLM_ALL2ALL_BACKEND: str = "naive" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840 + VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1 def get_default_cache_root(): @@ -828,6 +829,10 @@ environment_variables: dict[str, Callable[[], Any]] = { # This is used to prevent the kernel from running out of memory. "VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE": lambda: int(os.getenv("VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE", "163840")), + + # Regex timeout for use by the vLLM tool parsing plugins. + "VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS": + lambda: int(os.getenv("VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS", "1")), } # --8<-- [end:env-vars-definition]