mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-26 22:54:40 +08:00
[Bugfix] Add support for <tool_call> format in streaming mode for XLAM Tool Parser (#22769)
Signed-off-by: Devon Peroutky <devon@kindo.ai>
This commit is contained in:
parent
1cb39dbcdd
commit
422e793fa6
@ -2,12 +2,17 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaMessage, FunctionCall,
|
||||
ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers import xLAMToolParser
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
from vllm.transformers_utils.detokenizer import detokenize_incrementally
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
|
||||
|
||||
# Use a common model that is likely to be available
|
||||
MODEL = "Salesforce/Llama-xLAM-2-8B-fc-r"
|
||||
@ -36,6 +41,56 @@ def assert_tool_calls(actual_tool_calls: list[ToolCall],
|
||||
assert actual_tool_call.function == expected_tool_call.function
|
||||
|
||||
|
||||
def stream_delta_message_generator(
|
||||
xlam_tool_parser: xLAMToolParser,
|
||||
xlam_tokenizer: AnyTokenizer,
|
||||
model_output: str,
|
||||
request: Optional[ChatCompletionRequest] = None,
|
||||
) -> Generator[DeltaMessage, None, None]:
|
||||
all_token_ids = xlam_tokenizer.encode(model_output,
|
||||
add_special_tokens=False)
|
||||
|
||||
previous_text = ""
|
||||
previous_tokens = None
|
||||
prefix_offset = 0
|
||||
read_offset = 0
|
||||
for i, delta_token in enumerate(all_token_ids):
|
||||
delta_token_ids = [delta_token]
|
||||
previous_token_ids = all_token_ids[:i]
|
||||
current_token_ids = all_token_ids[:i + 1]
|
||||
|
||||
(new_tokens, delta_text, new_prefix_offset,
|
||||
new_read_offset) = (detokenize_incrementally(
|
||||
tokenizer=xlam_tokenizer,
|
||||
all_input_ids=current_token_ids,
|
||||
prev_tokens=previous_tokens,
|
||||
prefix_offset=prefix_offset,
|
||||
read_offset=read_offset,
|
||||
skip_special_tokens=False,
|
||||
spaces_between_special_tokens=True,
|
||||
))
|
||||
|
||||
current_text = previous_text + delta_text
|
||||
|
||||
delta_message = xlam_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text,
|
||||
current_text,
|
||||
delta_text,
|
||||
previous_token_ids,
|
||||
current_token_ids,
|
||||
delta_token_ids,
|
||||
request=request,
|
||||
)
|
||||
if delta_message:
|
||||
yield delta_message
|
||||
|
||||
previous_text = current_text
|
||||
previous_tokens = (previous_tokens +
|
||||
new_tokens if previous_tokens else new_tokens)
|
||||
prefix_offset = new_prefix_offset
|
||||
read_offset = new_read_offset
|
||||
|
||||
|
||||
def test_extract_tool_calls_no_tools(xlam_tool_parser):
|
||||
model_output = "This is a test"
|
||||
extracted_tool_calls = xlam_tool_parser.extract_tool_calls(
|
||||
@ -51,6 +106,7 @@ def test_extract_tool_calls_no_tools(xlam_tool_parser):
|
||||
"single_tool_with_think_tag",
|
||||
"single_tool_with_json_code_block",
|
||||
"single_tool_with_tool_calls_tag",
|
||||
"single_tool_with_tool_call_xml_tags",
|
||||
],
|
||||
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||
argvalues=[
|
||||
@ -118,6 +174,20 @@ def test_extract_tool_calls_no_tools(xlam_tool_parser):
|
||||
],
|
||||
"I'll check the weather for you.",
|
||||
),
|
||||
(
|
||||
"""I'll help you check the weather.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""", # noqa: E501
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}),
|
||||
))
|
||||
],
|
||||
"I'll help you check the weather.",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls(xlam_tool_parser, model_output,
|
||||
@ -245,3 +315,147 @@ def test_streaming_with_list_structure(xlam_tool_parser):
|
||||
assert hasattr(result, "tool_calls")
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].function.name == "get_current_weather"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
ids=[
|
||||
"parallel_tool_calls",
|
||||
"single_tool_with_think_tag",
|
||||
"single_tool_with_json_code_block",
|
||||
"single_tool_with_tool_calls_tag",
|
||||
"single_tool_with_tool_call_xml_tags",
|
||||
],
|
||||
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||
argvalues=[
|
||||
(
|
||||
"""[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}, {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}]""", # noqa: E501
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}),
|
||||
)),
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Orlando",
|
||||
"state": "FL",
|
||||
"unit": "fahrenheit",
|
||||
}),
|
||||
)),
|
||||
],
|
||||
"",
|
||||
),
|
||||
(
|
||||
"""<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""", # noqa: E501
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}),
|
||||
))
|
||||
],
|
||||
"<think>I'll help you with that.</think>",
|
||||
),
|
||||
(
|
||||
"""```json\n[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]\n```""", # noqa: E501
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}),
|
||||
))
|
||||
],
|
||||
"",
|
||||
),
|
||||
(
|
||||
"""[TOOL_CALLS][{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""", # noqa: E501
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}),
|
||||
))
|
||||
],
|
||||
"",
|
||||
),
|
||||
(
|
||||
"""I can help with that.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""", # noqa: E501
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}),
|
||||
))
|
||||
],
|
||||
"I can help with that.",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls_streaming_incremental(
|
||||
xlam_tool_parser,
|
||||
xlam_tokenizer,
|
||||
model_output,
|
||||
expected_tool_calls,
|
||||
expected_content,
|
||||
):
|
||||
"""Verify the XLAM Parser streaming behavior by verifying each chunk is as expected.""" # noqa: E501
|
||||
request = ChatCompletionRequest(model=MODEL, messages=[], tools=[])
|
||||
|
||||
chunks = []
|
||||
for delta_message in stream_delta_message_generator(
|
||||
xlam_tool_parser, xlam_tokenizer, model_output, request):
|
||||
chunks.append(delta_message)
|
||||
|
||||
# Should have multiple chunks
|
||||
assert len(chunks) >= 3
|
||||
|
||||
# Should have a chunk with tool header (id, name, type) for the first tool call # noqa: E501
|
||||
header_found = False
|
||||
expected_first_tool = expected_tool_calls[0]
|
||||
for chunk in chunks:
|
||||
if chunk.tool_calls and chunk.tool_calls[0].id:
|
||||
header_found = True
|
||||
assert (chunk.tool_calls[0].function.name ==
|
||||
expected_first_tool.function.name)
|
||||
assert chunk.tool_calls[0].type == "function"
|
||||
# Arguments may be empty initially or None
|
||||
if chunk.tool_calls[0].function.arguments is not None:
|
||||
# If present, should be empty string initially
|
||||
assert chunk.tool_calls[0].function.arguments == ""
|
||||
break
|
||||
assert header_found
|
||||
|
||||
# Should have chunks with incremental arguments
|
||||
arg_chunks = []
|
||||
for chunk in chunks:
|
||||
if (chunk.tool_calls and chunk.tool_calls[0].function.arguments
|
||||
and chunk.tool_calls[0].function.arguments != ""
|
||||
and chunk.tool_calls[0].index ==
|
||||
0 # Only collect arguments from the first tool call
|
||||
):
|
||||
arg_chunks.append(chunk.tool_calls[0].function.arguments)
|
||||
|
||||
# Arguments should be streamed incrementally
|
||||
assert len(arg_chunks) > 1
|
||||
|
||||
# Concatenated arguments should form valid JSON for the first tool call
|
||||
full_args = "".join(arg_chunks)
|
||||
parsed_args = json.loads(full_args)
|
||||
expected_args = json.loads(expected_first_tool.function.arguments)
|
||||
assert parsed_args == expected_args
|
||||
|
||||
@ -186,11 +186,31 @@ class xLAMToolParser(ToolParser):
|
||||
"""
|
||||
Extract tool calls for streaming mode.
|
||||
"""
|
||||
# Simplify detection: if it begins with "[" treat it as a function call
|
||||
is_function_call = (current_text.strip().startswith("["))
|
||||
# First, check for a definitive start of a tool call block.
|
||||
# This prevents premature parsing of incomplete output.
|
||||
stripped_text = current_text.strip()
|
||||
preprocessed_content, preprocessed_tool_calls = (
|
||||
self.preprocess_model_output(current_text))
|
||||
|
||||
# If not a function call, return normal content
|
||||
if not is_function_call:
|
||||
# For JSON code blocks, we need to detect them earlier, even if incomplete
|
||||
has_potential_json_block = ("```json" in current_text
|
||||
or "```\n[" in current_text
|
||||
or "[TOOL_CALLS]" in current_text
|
||||
or "<tool_call>" in current_text)
|
||||
|
||||
is_tool_call_block = (
|
||||
stripped_text.startswith("[")
|
||||
or stripped_text.startswith("<tool_call>")
|
||||
or stripped_text.startswith("[TOOL_CALLS]") or
|
||||
# Check if we have thinking tags with JSON-like content following
|
||||
("</think>[" in current_text) or
|
||||
# Check if the text contains a JSON array after preprocessing
|
||||
preprocessed_tool_calls is not None or
|
||||
# For JSON code blocks, detect early if we see enough structure
|
||||
(has_potential_json_block and '"name"' in current_text
|
||||
and '"arguments"' in current_text))
|
||||
|
||||
if not is_tool_call_block:
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
try:
|
||||
@ -204,7 +224,10 @@ class xLAMToolParser(ToolParser):
|
||||
|
||||
# Try parsing as JSON to check for complete tool calls
|
||||
try:
|
||||
parsed_tools = json.loads(current_text)
|
||||
# Use preprocessed tool calls if available
|
||||
tool_calls_text = (preprocessed_tool_calls if
|
||||
preprocessed_tool_calls else current_text)
|
||||
parsed_tools = json.loads(tool_calls_text)
|
||||
if isinstance(parsed_tools, list):
|
||||
# Update our tool array for next time
|
||||
self.prev_tool_call_arr = parsed_tools
|
||||
@ -257,13 +280,40 @@ class xLAMToolParser(ToolParser):
|
||||
return delta
|
||||
|
||||
# Use regex to identify tool calls in the output
|
||||
# Use preprocessed tool calls text for better parsing, but also try to extract from incomplete JSON blocks
|
||||
search_text = (preprocessed_tool_calls
|
||||
if preprocessed_tool_calls else current_text)
|
||||
|
||||
# For JSON code blocks that aren't complete yet, try to extract the JSON content
|
||||
if not preprocessed_tool_calls and has_potential_json_block:
|
||||
# Try to extract the JSON array from within the code block
|
||||
json_match = re.search(r"```(?:json)?\s*([\s\S]*?)(?:```|$)",
|
||||
current_text)
|
||||
if json_match:
|
||||
potential_json = json_match.group(1).strip()
|
||||
# Use this as search text even if it's incomplete
|
||||
if potential_json.startswith("[") and (
|
||||
'"name"' in potential_json
|
||||
and '"arguments"' in potential_json):
|
||||
search_text = potential_json
|
||||
|
||||
# Try to find complete tool names first
|
||||
name_pattern = r'"name"\s*:\s*"([^"]+)"'
|
||||
name_matches = list(re.finditer(name_pattern, current_text))
|
||||
name_matches = list(re.finditer(name_pattern, search_text))
|
||||
tool_count = len(name_matches)
|
||||
|
||||
# If no tools found yet, return
|
||||
# If no complete tool names found, check for partial tool names
|
||||
if tool_count == 0:
|
||||
return None
|
||||
# Check if we're in the middle of parsing a tool name
|
||||
partial_name_pattern = r'"name"\s*:\s*"([^"]*)'
|
||||
partial_matches = list(
|
||||
re.finditer(partial_name_pattern, search_text))
|
||||
if partial_matches:
|
||||
# We have a partial tool name - not ready to emit yet
|
||||
return None
|
||||
else:
|
||||
# No tools found at all
|
||||
return None
|
||||
|
||||
# Ensure our state arrays are large enough
|
||||
while len(self.streaming_state["sent_tools"]) < tool_count:
|
||||
@ -332,7 +382,7 @@ class xLAMToolParser(ToolParser):
|
||||
# First, check for the empty arguments case: "arguments": {}
|
||||
empty_args_pattern = (
|
||||
r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{\s*\}')
|
||||
empty_args_match = re.search(empty_args_pattern, current_text)
|
||||
empty_args_match = re.search(empty_args_pattern, search_text)
|
||||
|
||||
# Check if this tool has empty arguments
|
||||
if empty_args_match and empty_args_match.start() > 0:
|
||||
@ -376,7 +426,7 @@ class xLAMToolParser(ToolParser):
|
||||
|
||||
# Extract arguments for current tool using regex for non-empty arguments
|
||||
args_pattern = r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*(\{(?:[^{}]|(?:\{[^{}]*\}))*\})'
|
||||
args_matches = list(re.finditer(args_pattern, current_text))
|
||||
args_matches = list(re.finditer(args_pattern, search_text))
|
||||
|
||||
if current_idx < len(args_matches):
|
||||
args_text = args_matches[current_idx].group(1)
|
||||
@ -384,17 +434,25 @@ class xLAMToolParser(ToolParser):
|
||||
# Handle transition between tools
|
||||
is_last_tool = current_idx == tool_count - 1
|
||||
|
||||
# Find where the arguments for our current tool end
|
||||
if not is_last_tool:
|
||||
# If we have more tools after this one, try to find the complete argument block
|
||||
next_tool_pos = current_text.find(
|
||||
"},{", args_matches[current_idx].start())
|
||||
if next_tool_pos != -1:
|
||||
args_end_pos = (next_tool_pos + 1
|
||||
) # +1 to include the '}'
|
||||
args_text = (current_text[args_matches[current_idx]
|
||||
.start():args_end_pos].
|
||||
split('"arguments":')[1].strip())
|
||||
# For multiple tools, extract only the arguments for the current tool
|
||||
if tool_count > 1:
|
||||
# Parse the entire JSON structure to properly extract arguments for each tool
|
||||
try:
|
||||
parsed_tools = json.loads(search_text)
|
||||
if isinstance(
|
||||
parsed_tools,
|
||||
list) and current_idx < len(parsed_tools):
|
||||
current_tool = parsed_tools[current_idx]
|
||||
if isinstance(current_tool.get("arguments"),
|
||||
dict):
|
||||
args_text = json.dumps(
|
||||
current_tool["arguments"])
|
||||
else:
|
||||
args_text = str(
|
||||
current_tool.get("arguments", "{}"))
|
||||
except (json.JSONDecodeError, KeyError, IndexError):
|
||||
# Fallback to regex-based extraction
|
||||
pass
|
||||
|
||||
# If arguments haven't been sent yet
|
||||
sent_args = self.streaming_state["sent_tools"][
|
||||
@ -419,7 +477,7 @@ class xLAMToolParser(ToolParser):
|
||||
index=current_idx,
|
||||
function=DeltaFunctionCall(
|
||||
arguments="{").model_dump(
|
||||
exclude_none=True), # type: ignore
|
||||
exclude_none=True), # type: ignore
|
||||
)
|
||||
])
|
||||
return delta
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user