[Chore]: Stream tokens vs characters in tool call parser tests (#26513)
Signed-off-by: Ben Browning <bbrownin@redhat.com>
Parent: 23ad820553
Commit: 3b96f85c36
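The streaming tool-parser tests previously split model output into single-character deltas and built parsers around MagicMock tokenizers. This change introduces a shared default_tokenizer fixture (gpt2) in conftest.py, threads it through the Llama3 JSON, llama4_pythonic, olmo3, and pythonic parser tests in place of the mocks, and adds a split_string_into_token_deltas helper so that run_tool_extraction_streaming feeds parsers token-sized deltas, closer to what they receive in production.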
tests/entrypoints/openai/tool_parsers/conftest.py (new file, +12)

@@ -0,0 +1,12 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+from transformers import AutoTokenizer
+
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+
+
+@pytest.fixture(scope="function")
+def default_tokenizer() -> AnyTokenizer:
+    return AutoTokenizer.from_pretrained("gpt2")
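The fixture is function-scoped, so each test gets a fresh tokenizer, and pytest makes it available to every test module in the directory without an import. A minimal sketch of a consuming test (this test is illustrative, not part of the commit):

def test_default_tokenizer_roundtrip(default_tokenizer):
    # gpt2's byte-level BPE round-trips plain ASCII text exactly
    token_ids = default_tokenizer.encode("hello world", add_special_tokens=False)
    assert default_tokenizer.decode(token_ids) == "hello world"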
Llama3JsonToolParser test module:

@@ -2,17 +2,15 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import pytest
-from transformers import AutoTokenizer

 from vllm.entrypoints.openai.protocol import ExtractedToolCallInformation
 from vllm.entrypoints.openai.tool_parsers.llama_tool_parser import Llama3JsonToolParser
+from vllm.transformers_utils.tokenizer import AnyTokenizer


 @pytest.fixture
-def parser():
-    # Use a small tokenizer for testing
-    tokenizer = AutoTokenizer.from_pretrained("gpt2")
-    return Llama3JsonToolParser(tokenizer)
+def parser(default_tokenizer: AnyTokenizer):
+    return Llama3JsonToolParser(default_tokenizer)


 def test_extract_tool_calls_simple(parser):
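The parser fixture now depends on the conftest fixture: pytest builds default_tokenizer first and injects it into parser. The same dependency pattern in a self-contained sketch (the names here are illustrative):

import pytest


@pytest.fixture
def base_tokenizer():
    return "stand-in tokenizer"


@pytest.fixture
def parser(base_tokenizer):
    # pytest resolves base_tokenizer before constructing this fixture
    return ("parser around", base_tokenizer)


def test_parser_sees_tokenizer(parser):
    assert parser[1] == "stand-in tokenizer"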
llama4_pythonic tool parser test module:

@@ -11,6 +11,7 @@ from tests.entrypoints.openai.tool_parsers.utils import (
 )
 from vllm.entrypoints.openai.protocol import FunctionCall
 from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
+from vllm.transformers_utils.tokenizer import AnyTokenizer

 # Test cases similar to pythonic parser but with Llama4 specific format
 SIMPLE_FUNCTION_OUTPUT = "[get_weather(city='LA', metric='C')]"

@@ -63,10 +64,9 @@ PYTHON_TAG_FUNCTION_OUTPUT = (


 @pytest.mark.parametrize("streaming", [True, False])
-def test_no_tool_call(streaming: bool):
-    mock_tokenizer = MagicMock()
+def test_no_tool_call(streaming: bool, default_tokenizer: AnyTokenizer):
     tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
-        mock_tokenizer
+        default_tokenizer
     )
     model_output = "How can I help you today?"


@@ -205,11 +205,13 @@ TEST_CASES = [

 @pytest.mark.parametrize("streaming, model_output, expected_tool_calls", TEST_CASES)
 def test_tool_call(
-    streaming: bool, model_output: str, expected_tool_calls: list[FunctionCall]
+    streaming: bool,
+    model_output: str,
+    expected_tool_calls: list[FunctionCall],
+    default_tokenizer: AnyTokenizer,
 ):
-    mock_tokenizer = MagicMock()
     tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
-        mock_tokenizer
+        default_tokenizer
     )

     content, tool_calls = run_tool_extraction(

@@ -222,10 +224,9 @@ def test_tool_call(
     assert actual.function == expected


-def test_streaming_tool_call_with_large_steps():
-    mock_tokenizer = MagicMock()
+def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
     tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
-        mock_tokenizer
+        default_tokenizer
     )
     model_output_deltas = [
         "<|python_start|>[get_weather(city='LA', metric='C'), "

@@ -245,11 +246,10 @@ def test_streaming_tool_call_with_large_steps():


 @pytest.mark.parametrize("streaming", [False])
-def test_regex_timeout_handling(streaming: bool):
+def test_regex_timeout_handling(streaming: bool, default_tokenizer: AnyTokenizer):
     """test regex timeout is handled gracefully"""
-    mock_tokenizer = MagicMock()
     tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
-        mock_tokenizer
+        default_tokenizer
     )

     fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2
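One point worth noting about streaming these deltas through gpt2: <|python_start|> is not a special token in gpt2's vocabulary (its only special token is <|endoftext|>), so the tag is encoded as several ordinary tokens and arrives a few characters at a time, which is exactly the incremental arrival the streaming tests exercise. A quick check (the exact split depends on the BPE merges):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
ids = tokenizer.encode("<|python_start|>", add_special_tokens=False)
assert len(ids) > 1  # the tag spans multiple token deltas
assert tokenizer.decode(ids) == "<|python_start|>"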
olmo3 tool parser test module:

@@ -11,6 +11,7 @@ from tests.entrypoints.openai.tool_parsers.utils import (
 )
 from vllm.entrypoints.openai.protocol import FunctionCall
 from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
+from vllm.transformers_utils.tokenizer import AnyTokenizer

 # https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
 SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"

@@ -68,9 +69,10 @@ ESCAPED_STRING_FUNCTION_CALL = FunctionCall(


 @pytest.mark.parametrize("streaming", [True, False])
-def test_no_tool_call(streaming: bool):
-    mock_tokenizer = MagicMock()
-    tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(mock_tokenizer)
+def test_no_tool_call(streaming: bool, default_tokenizer: AnyTokenizer):
+    tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
+        default_tokenizer
+    )
     model_output = "How can I help you today?"

     content, tool_calls = run_tool_extraction(

@@ -183,10 +185,14 @@ TEST_CASES = [

 @pytest.mark.parametrize("streaming, model_output, expected_tool_calls", TEST_CASES)
 def test_tool_call(
-    streaming: bool, model_output: str, expected_tool_calls: list[FunctionCall]
+    streaming: bool,
+    model_output: str,
+    expected_tool_calls: list[FunctionCall],
+    default_tokenizer: AnyTokenizer,
 ):
-    mock_tokenizer = MagicMock()
-    tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(mock_tokenizer)
+    tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
+        default_tokenizer
+    )

     content, tool_calls = run_tool_extraction(
         tool_parser, model_output, streaming=streaming

@@ -199,9 +205,10 @@ def test_tool_call(
     assert actual.function == expected


-def test_streaming_tool_call_with_large_steps():
-    mock_tokenizer = MagicMock()
-    tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(mock_tokenizer)
+def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
+    tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
+        default_tokenizer
+    )
     model_output_deltas = [
         "<function_calls>get_weather(city='San",
         " Francisco', metric='celsius')\n"

@@ -221,10 +228,11 @@ def test_streaming_tool_call_with_large_steps():


 @pytest.mark.parametrize("streaming", [False])
-def test_regex_timeout_handling(streaming: bool):
+def test_regex_timeout_handling(streaming: bool, default_tokenizer: AnyTokenizer):
     """test regex timeout is handled gracefully"""
-    mock_tokenizer = MagicMock()
-    tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(mock_tokenizer)
+    tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
+        default_tokenizer
+    )

     fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2

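Replacing the MagicMock tokenizers is not just cosmetic. The streaming helper added in utils.py (see the last hunk below) calls encode and len on the parser's tokenizer, and a MagicMock answers len() with 0 by default, so the split would silently produce no deltas at all. A small demonstration of the failure mode:

from unittest.mock import MagicMock

mock_tokenizer = MagicMock()
token_ids = mock_tokenizer.encode("hello", add_special_tokens=False)
print(type(token_ids).__name__)  # MagicMock, not a list of ints
print(len(token_ids))            # 0: MagicMock's default __len__
# range(1, len(token_ids) + 1) is therefore empty and yields zero deltas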
pythonic tool parser test module:

@@ -11,6 +11,7 @@ from tests.entrypoints.openai.tool_parsers.utils import (
 )
 from vllm.entrypoints.openai.protocol import FunctionCall
 from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
+from vllm.transformers_utils.tokenizer import AnyTokenizer

 # https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
 SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"

@@ -60,10 +61,9 @@ ESCAPED_STRING_FUNCTION_CALL = FunctionCall(


 @pytest.mark.parametrize("streaming", [True, False])
-def test_no_tool_call(streaming: bool):
-    mock_tokenizer = MagicMock()
+def test_no_tool_call(streaming: bool, default_tokenizer: AnyTokenizer):
     tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
-        mock_tokenizer
+        default_tokenizer
     )
     model_output = "How can I help you today?"


@@ -165,11 +165,13 @@ TEST_CASES = [

 @pytest.mark.parametrize("streaming, model_output, expected_tool_calls", TEST_CASES)
 def test_tool_call(
-    streaming: bool, model_output: str, expected_tool_calls: list[FunctionCall]
+    streaming: bool,
+    model_output: str,
+    expected_tool_calls: list[FunctionCall],
+    default_tokenizer: AnyTokenizer,
 ):
-    mock_tokenizer = MagicMock()
     tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
-        mock_tokenizer
+        default_tokenizer
     )

     content, tool_calls = run_tool_extraction(

@@ -183,10 +185,9 @@ def test_tool_call(
     assert actual.function == expected


-def test_streaming_tool_call_with_large_steps():
-    mock_tokenizer = MagicMock()
+def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
     tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
-        mock_tokenizer
+        default_tokenizer
     )
     model_output_deltas = [
         "[get_weather(city='San",

@@ -207,11 +208,10 @@ def test_streaming_tool_call_with_large_steps():


 @pytest.mark.parametrize("streaming", [False])
-def test_regex_timeout_handling(streaming: bool):
+def test_regex_timeout_handling(streaming: bool, default_tokenizer: AnyTokenizer):
     """test regex timeout is handled gracefully"""
-    mock_tokenizer = MagicMock()
     tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
-        mock_tokenizer
+        default_tokenizer
     )

     fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2
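For reference, the regex-timeout tests (here and in the llama4_pythonic and olmo3 modules) build their pathological input by string repetition; expanding it shows the literal string the parser receives:

fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2
print(repr(fake_problematic_input))
# 'hello world[A(A=\t)A(A=,\t\t)A(A=,\t'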
tests/entrypoints/openai/tool_parsers/utils.py:

@@ -11,6 +11,7 @@ from vllm.entrypoints.openai.protocol import (
     ToolCall,
 )
 from vllm.entrypoints.openai.tool_parsers import ToolParser
+from vllm.transformers_utils.tokenizer import AnyTokenizer


 class StreamingToolReconstructor:

@@ -110,12 +111,32 @@ def run_tool_extraction_nonstreaming(
     return tool_parser.extract_tool_calls(model_output, request)


+def split_string_into_token_deltas(tokenizer: AnyTokenizer, text: str) -> list[str]:
+    # Split a string into a series of deltas using the provided tokenizer. Each
+    # delta will be the string equivalent of a single token.
+    token_ids = tokenizer.encode(text, add_special_tokens=False)
+    previously_decoded_text = ""
+    deltas = []
+    for i in range(1, len(token_ids) + 1):
+        current_tokens = token_ids[:i]
+        current_text = tokenizer.decode(current_tokens)
+        new_text = current_text[len(previously_decoded_text) :]
+        previously_decoded_text = current_text
+        deltas.append(new_text)
+    return deltas
+
+
 def run_tool_extraction_streaming(
     tool_parser: ToolParser,
     model_deltas: Iterable[str],
     request: ChatCompletionRequest | None = None,
     assert_one_tool_per_delta: bool = True,
 ) -> StreamingToolReconstructor:
+    if isinstance(model_deltas, str):
+        model_deltas = split_string_into_token_deltas(
+            tool_parser.model_tokenizer, model_deltas
+        )
+
     request = request or ChatCompletionRequest(messages=[], model="test-model")
     reconstructor = StreamingToolReconstructor(
         assert_one_tool_per_delta=assert_one_tool_per_delta
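With this helper in place, a test can pass a whole model output string to run_tool_extraction_streaming and have it split on token boundaries instead of into single characters. A standalone sketch of the difference; the exact token deltas depend on gpt2's vocabulary, so only their reassembly is asserted:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
text = "get_weather(city='San Francisco', metric='celsius')"

# Old behavior: one delta per character.
char_deltas = list(text)

# New behavior: one delta per token, recovered by decoding ever-longer
# prefixes of the token ids and keeping only the newly produced text.
token_ids = tokenizer.encode(text, add_special_tokens=False)
token_deltas, previous = [], ""
for i in range(1, len(token_ids) + 1):
    current = tokenizer.decode(token_ids[:i])
    token_deltas.append(current[len(previous):])
    previous = current

assert "".join(token_deltas) == text         # deltas reassemble the input
assert len(token_deltas) < len(char_deltas)  # fewer, multi-character chunks

Decoding prefixes, rather than decoding each token id in isolation, keeps delta boundaries clean when one on-screen character is assembled from several tokens, for example the bytes of a multi-byte UTF-8 sequence in byte-level BPE.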