mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 00:15:51 +08:00
Tool parser regex timeout handling (#18960)
Signed-off-by: Will Eaton <weaton@redhat.com>
This commit is contained in:
parent
7f21e8052b
commit
1dab4d5718
@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@ -191,3 +191,27 @@ def test_streaming_tool_call_with_large_steps():
|
||||
assert reconstructor.tool_calls[0].function == SIMPLE_FUNCTION_CALL
|
||||
assert reconstructor.tool_calls[1].function == PARAMETERLESS_FUNCTION_CALL
|
||||
assert reconstructor.tool_calls[2].function == EMPTY_LIST_FUNCTION_CALL
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [False])
|
||||
def test_regex_timeout_handling(streaming: bool):
|
||||
"""test regex timeout is handled gracefully"""
|
||||
mock_tokenizer = MagicMock()
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser(
|
||||
"llama4_pythonic")(mock_tokenizer)
|
||||
|
||||
fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2
|
||||
|
||||
# create a mock regex that raises TimeoutError
|
||||
mock_regex = MagicMock()
|
||||
mock_regex.match.side_effect = TimeoutError("Regex timeout")
|
||||
|
||||
with patch.object(tool_parser, 'TOOL_CALL_REGEX', mock_regex):
|
||||
content, tool_calls = run_tool_extraction(tool_parser,
|
||||
fake_problematic_input,
|
||||
streaming=streaming)
|
||||
|
||||
# should treat as regular text when regex times out
|
||||
assert content == fake_problematic_input
|
||||
assert len(tool_calls) == 0
|
||||
mock_regex.match.assert_called_once()
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@ -159,3 +159,27 @@ def test_streaming_tool_call_with_large_steps():
|
||||
assert reconstructor.tool_calls[0].function == SIMPLE_FUNCTION_CALL
|
||||
assert reconstructor.tool_calls[1].function == PARAMETERLESS_FUNCTION_CALL
|
||||
assert reconstructor.tool_calls[2].function == EMPTY_LIST_FUNCTION_CALL
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [False])
|
||||
def test_regex_timeout_handling(streaming: bool):
|
||||
"""test regex timeout is handled gracefully"""
|
||||
mock_tokenizer = MagicMock()
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser(
|
||||
"llama4_pythonic")(mock_tokenizer)
|
||||
|
||||
fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2
|
||||
|
||||
# create a mock regex that raises TimeoutError
|
||||
mock_regex = MagicMock()
|
||||
mock_regex.match.side_effect = TimeoutError("Regex timeout")
|
||||
|
||||
with patch.object(tool_parser, 'TOOL_CALL_REGEX', mock_regex):
|
||||
content, tool_calls = run_tool_extraction(tool_parser,
|
||||
fake_problematic_input,
|
||||
streaming=streaming)
|
||||
|
||||
# should treat as regular text when regex times out
|
||||
assert content == fake_problematic_input
|
||||
assert len(tool_calls) == 0
|
||||
mock_regex.match.assert_called_once()
|
||||
|
||||
@ -7,6 +7,7 @@ from typing import Any, Union
|
||||
import regex as re
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
@ -64,7 +65,19 @@ class Llama4PythonicToolParser(ToolParser):
|
||||
if model_output.startswith("<|python_start|>"):
|
||||
model_output = model_output[len("<|python_start|>"):]
|
||||
model_output = model_output.replace("<|python_end|>", "")
|
||||
if not (self.TOOL_CALL_REGEX.match(model_output)):
|
||||
|
||||
is_tool_call_pattern = False
|
||||
try:
|
||||
is_tool_call_pattern = self.TOOL_CALL_REGEX.match(
|
||||
model_output,
|
||||
timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS) is not None
|
||||
except TimeoutError:
|
||||
logger.warning(
|
||||
"Regex timeout occurred when matching tool call pattern.")
|
||||
logger.debug("Regex timeout occurred when matching user input: %s",
|
||||
model_output)
|
||||
|
||||
if not is_tool_call_pattern:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
@ -8,6 +8,7 @@ from typing import Any, Union
|
||||
import regex as re
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
@ -61,8 +62,18 @@ class PythonicToolParser(ToolParser):
|
||||
"""
|
||||
Extract the tool calls from a complete model response.
|
||||
"""
|
||||
is_tool_call_pattern = False
|
||||
try:
|
||||
is_tool_call_pattern = self.TOOL_CALL_REGEX.match(
|
||||
model_output,
|
||||
timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS) is not None
|
||||
except TimeoutError:
|
||||
logger.warning(
|
||||
"Regex timeout occurred when matching tool call pattern.")
|
||||
logger.debug("Regex timeout occurred when matching user input: %s",
|
||||
model_output)
|
||||
|
||||
if not (self.TOOL_CALL_REGEX.match(model_output)):
|
||||
if not is_tool_call_pattern:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
@ -119,6 +119,7 @@ if TYPE_CHECKING:
|
||||
VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557
|
||||
VLLM_ALL2ALL_BACKEND: str = "naive"
|
||||
VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840
|
||||
VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1
|
||||
|
||||
|
||||
def get_default_cache_root():
|
||||
@ -828,6 +829,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
# This is used to prevent the kernel from running out of memory.
|
||||
"VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE":
|
||||
lambda: int(os.getenv("VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE", "163840")),
|
||||
|
||||
# Regex timeout for use by the vLLM tool parsing plugins.
|
||||
"VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS":
|
||||
lambda: int(os.getenv("VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS", "1")),
|
||||
}
|
||||
|
||||
# --8<-- [end:env-vars-definition]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user