mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-09 04:22:14 +08:00
[Bugfix] Add support for <tool_call> format in streaming mode for XLAM Tool Parser (#22769)
Signed-off-by: Devon Peroutky <devon@kindo.ai>
This commit is contained in:
parent
1cb39dbcdd
commit
422e793fa6
@ -2,12 +2,17 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
from collections.abc import Generator
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
|
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||||
|
DeltaMessage, FunctionCall,
|
||||||
|
ToolCall)
|
||||||
from vllm.entrypoints.openai.tool_parsers import xLAMToolParser
|
from vllm.entrypoints.openai.tool_parsers import xLAMToolParser
|
||||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
from vllm.transformers_utils.detokenizer import detokenize_incrementally
|
||||||
|
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
|
||||||
|
|
||||||
# Use a common model that is likely to be available
|
# Use a common model that is likely to be available
|
||||||
MODEL = "Salesforce/Llama-xLAM-2-8B-fc-r"
|
MODEL = "Salesforce/Llama-xLAM-2-8B-fc-r"
|
||||||
@ -36,6 +41,56 @@ def assert_tool_calls(actual_tool_calls: list[ToolCall],
|
|||||||
assert actual_tool_call.function == expected_tool_call.function
|
assert actual_tool_call.function == expected_tool_call.function
|
||||||
|
|
||||||
|
|
||||||
|
def stream_delta_message_generator(
|
||||||
|
xlam_tool_parser: xLAMToolParser,
|
||||||
|
xlam_tokenizer: AnyTokenizer,
|
||||||
|
model_output: str,
|
||||||
|
request: Optional[ChatCompletionRequest] = None,
|
||||||
|
) -> Generator[DeltaMessage, None, None]:
|
||||||
|
all_token_ids = xlam_tokenizer.encode(model_output,
|
||||||
|
add_special_tokens=False)
|
||||||
|
|
||||||
|
previous_text = ""
|
||||||
|
previous_tokens = None
|
||||||
|
prefix_offset = 0
|
||||||
|
read_offset = 0
|
||||||
|
for i, delta_token in enumerate(all_token_ids):
|
||||||
|
delta_token_ids = [delta_token]
|
||||||
|
previous_token_ids = all_token_ids[:i]
|
||||||
|
current_token_ids = all_token_ids[:i + 1]
|
||||||
|
|
||||||
|
(new_tokens, delta_text, new_prefix_offset,
|
||||||
|
new_read_offset) = (detokenize_incrementally(
|
||||||
|
tokenizer=xlam_tokenizer,
|
||||||
|
all_input_ids=current_token_ids,
|
||||||
|
prev_tokens=previous_tokens,
|
||||||
|
prefix_offset=prefix_offset,
|
||||||
|
read_offset=read_offset,
|
||||||
|
skip_special_tokens=False,
|
||||||
|
spaces_between_special_tokens=True,
|
||||||
|
))
|
||||||
|
|
||||||
|
current_text = previous_text + delta_text
|
||||||
|
|
||||||
|
delta_message = xlam_tool_parser.extract_tool_calls_streaming(
|
||||||
|
previous_text,
|
||||||
|
current_text,
|
||||||
|
delta_text,
|
||||||
|
previous_token_ids,
|
||||||
|
current_token_ids,
|
||||||
|
delta_token_ids,
|
||||||
|
request=request,
|
||||||
|
)
|
||||||
|
if delta_message:
|
||||||
|
yield delta_message
|
||||||
|
|
||||||
|
previous_text = current_text
|
||||||
|
previous_tokens = (previous_tokens +
|
||||||
|
new_tokens if previous_tokens else new_tokens)
|
||||||
|
prefix_offset = new_prefix_offset
|
||||||
|
read_offset = new_read_offset
|
||||||
|
|
||||||
|
|
||||||
def test_extract_tool_calls_no_tools(xlam_tool_parser):
|
def test_extract_tool_calls_no_tools(xlam_tool_parser):
|
||||||
model_output = "This is a test"
|
model_output = "This is a test"
|
||||||
extracted_tool_calls = xlam_tool_parser.extract_tool_calls(
|
extracted_tool_calls = xlam_tool_parser.extract_tool_calls(
|
||||||
@ -51,6 +106,7 @@ def test_extract_tool_calls_no_tools(xlam_tool_parser):
|
|||||||
"single_tool_with_think_tag",
|
"single_tool_with_think_tag",
|
||||||
"single_tool_with_json_code_block",
|
"single_tool_with_json_code_block",
|
||||||
"single_tool_with_tool_calls_tag",
|
"single_tool_with_tool_calls_tag",
|
||||||
|
"single_tool_with_tool_call_xml_tags",
|
||||||
],
|
],
|
||||||
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||||
argvalues=[
|
argvalues=[
|
||||||
@ -118,6 +174,20 @@ def test_extract_tool_calls_no_tools(xlam_tool_parser):
|
|||||||
],
|
],
|
||||||
"I'll check the weather for you.",
|
"I'll check the weather for you.",
|
||||||
),
|
),
|
||||||
|
(
|
||||||
|
"""I'll help you check the weather.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""", # noqa: E501
|
||||||
|
[
|
||||||
|
ToolCall(function=FunctionCall(
|
||||||
|
name="get_current_weather",
|
||||||
|
arguments=json.dumps({
|
||||||
|
"city": "Dallas",
|
||||||
|
"state": "TX",
|
||||||
|
"unit": "fahrenheit",
|
||||||
|
}),
|
||||||
|
))
|
||||||
|
],
|
||||||
|
"I'll help you check the weather.",
|
||||||
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_extract_tool_calls(xlam_tool_parser, model_output,
|
def test_extract_tool_calls(xlam_tool_parser, model_output,
|
||||||
@ -245,3 +315,147 @@ def test_streaming_with_list_structure(xlam_tool_parser):
|
|||||||
assert hasattr(result, "tool_calls")
|
assert hasattr(result, "tool_calls")
|
||||||
assert len(result.tool_calls) == 1
|
assert len(result.tool_calls) == 1
|
||||||
assert result.tool_calls[0].function.name == "get_current_weather"
|
assert result.tool_calls[0].function.name == "get_current_weather"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
ids=[
|
||||||
|
"parallel_tool_calls",
|
||||||
|
"single_tool_with_think_tag",
|
||||||
|
"single_tool_with_json_code_block",
|
||||||
|
"single_tool_with_tool_calls_tag",
|
||||||
|
"single_tool_with_tool_call_xml_tags",
|
||||||
|
],
|
||||||
|
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||||
|
argvalues=[
|
||||||
|
(
|
||||||
|
"""[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}, {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}]""", # noqa: E501
|
||||||
|
[
|
||||||
|
ToolCall(function=FunctionCall(
|
||||||
|
name="get_current_weather",
|
||||||
|
arguments=json.dumps({
|
||||||
|
"city": "Dallas",
|
||||||
|
"state": "TX",
|
||||||
|
"unit": "fahrenheit",
|
||||||
|
}),
|
||||||
|
)),
|
||||||
|
ToolCall(function=FunctionCall(
|
||||||
|
name="get_current_weather",
|
||||||
|
arguments=json.dumps({
|
||||||
|
"city": "Orlando",
|
||||||
|
"state": "FL",
|
||||||
|
"unit": "fahrenheit",
|
||||||
|
}),
|
||||||
|
)),
|
||||||
|
],
|
||||||
|
"",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"""<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""", # noqa: E501
|
||||||
|
[
|
||||||
|
ToolCall(function=FunctionCall(
|
||||||
|
name="get_current_weather",
|
||||||
|
arguments=json.dumps({
|
||||||
|
"city": "Dallas",
|
||||||
|
"state": "TX",
|
||||||
|
"unit": "fahrenheit",
|
||||||
|
}),
|
||||||
|
))
|
||||||
|
],
|
||||||
|
"<think>I'll help you with that.</think>",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"""```json\n[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]\n```""", # noqa: E501
|
||||||
|
[
|
||||||
|
ToolCall(function=FunctionCall(
|
||||||
|
name="get_current_weather",
|
||||||
|
arguments=json.dumps({
|
||||||
|
"city": "Dallas",
|
||||||
|
"state": "TX",
|
||||||
|
"unit": "fahrenheit",
|
||||||
|
}),
|
||||||
|
))
|
||||||
|
],
|
||||||
|
"",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"""[TOOL_CALLS][{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""", # noqa: E501
|
||||||
|
[
|
||||||
|
ToolCall(function=FunctionCall(
|
||||||
|
name="get_current_weather",
|
||||||
|
arguments=json.dumps({
|
||||||
|
"city": "Dallas",
|
||||||
|
"state": "TX",
|
||||||
|
"unit": "fahrenheit",
|
||||||
|
}),
|
||||||
|
))
|
||||||
|
],
|
||||||
|
"",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"""I can help with that.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""", # noqa: E501
|
||||||
|
[
|
||||||
|
ToolCall(function=FunctionCall(
|
||||||
|
name="get_current_weather",
|
||||||
|
arguments=json.dumps({
|
||||||
|
"city": "Dallas",
|
||||||
|
"state": "TX",
|
||||||
|
"unit": "fahrenheit",
|
||||||
|
}),
|
||||||
|
))
|
||||||
|
],
|
||||||
|
"I can help with that.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_extract_tool_calls_streaming_incremental(
|
||||||
|
xlam_tool_parser,
|
||||||
|
xlam_tokenizer,
|
||||||
|
model_output,
|
||||||
|
expected_tool_calls,
|
||||||
|
expected_content,
|
||||||
|
):
|
||||||
|
"""Verify the XLAM Parser streaming behavior by verifying each chunk is as expected.""" # noqa: E501
|
||||||
|
request = ChatCompletionRequest(model=MODEL, messages=[], tools=[])
|
||||||
|
|
||||||
|
chunks = []
|
||||||
|
for delta_message in stream_delta_message_generator(
|
||||||
|
xlam_tool_parser, xlam_tokenizer, model_output, request):
|
||||||
|
chunks.append(delta_message)
|
||||||
|
|
||||||
|
# Should have multiple chunks
|
||||||
|
assert len(chunks) >= 3
|
||||||
|
|
||||||
|
# Should have a chunk with tool header (id, name, type) for the first tool call # noqa: E501
|
||||||
|
header_found = False
|
||||||
|
expected_first_tool = expected_tool_calls[0]
|
||||||
|
for chunk in chunks:
|
||||||
|
if chunk.tool_calls and chunk.tool_calls[0].id:
|
||||||
|
header_found = True
|
||||||
|
assert (chunk.tool_calls[0].function.name ==
|
||||||
|
expected_first_tool.function.name)
|
||||||
|
assert chunk.tool_calls[0].type == "function"
|
||||||
|
# Arguments may be empty initially or None
|
||||||
|
if chunk.tool_calls[0].function.arguments is not None:
|
||||||
|
# If present, should be empty string initially
|
||||||
|
assert chunk.tool_calls[0].function.arguments == ""
|
||||||
|
break
|
||||||
|
assert header_found
|
||||||
|
|
||||||
|
# Should have chunks with incremental arguments
|
||||||
|
arg_chunks = []
|
||||||
|
for chunk in chunks:
|
||||||
|
if (chunk.tool_calls and chunk.tool_calls[0].function.arguments
|
||||||
|
and chunk.tool_calls[0].function.arguments != ""
|
||||||
|
and chunk.tool_calls[0].index ==
|
||||||
|
0 # Only collect arguments from the first tool call
|
||||||
|
):
|
||||||
|
arg_chunks.append(chunk.tool_calls[0].function.arguments)
|
||||||
|
|
||||||
|
# Arguments should be streamed incrementally
|
||||||
|
assert len(arg_chunks) > 1
|
||||||
|
|
||||||
|
# Concatenated arguments should form valid JSON for the first tool call
|
||||||
|
full_args = "".join(arg_chunks)
|
||||||
|
parsed_args = json.loads(full_args)
|
||||||
|
expected_args = json.loads(expected_first_tool.function.arguments)
|
||||||
|
assert parsed_args == expected_args
|
||||||
|
|||||||
@ -186,11 +186,31 @@ class xLAMToolParser(ToolParser):
|
|||||||
"""
|
"""
|
||||||
Extract tool calls for streaming mode.
|
Extract tool calls for streaming mode.
|
||||||
"""
|
"""
|
||||||
# Simplify detection: if it begins with "[" treat it as a function call
|
# First, check for a definitive start of a tool call block.
|
||||||
is_function_call = (current_text.strip().startswith("["))
|
# This prevents premature parsing of incomplete output.
|
||||||
|
stripped_text = current_text.strip()
|
||||||
|
preprocessed_content, preprocessed_tool_calls = (
|
||||||
|
self.preprocess_model_output(current_text))
|
||||||
|
|
||||||
# If not a function call, return normal content
|
# For JSON code blocks, we need to detect them earlier, even if incomplete
|
||||||
if not is_function_call:
|
has_potential_json_block = ("```json" in current_text
|
||||||
|
or "```\n[" in current_text
|
||||||
|
or "[TOOL_CALLS]" in current_text
|
||||||
|
or "<tool_call>" in current_text)
|
||||||
|
|
||||||
|
is_tool_call_block = (
|
||||||
|
stripped_text.startswith("[")
|
||||||
|
or stripped_text.startswith("<tool_call>")
|
||||||
|
or stripped_text.startswith("[TOOL_CALLS]") or
|
||||||
|
# Check if we have thinking tags with JSON-like content following
|
||||||
|
("</think>[" in current_text) or
|
||||||
|
# Check if the text contains a JSON array after preprocessing
|
||||||
|
preprocessed_tool_calls is not None or
|
||||||
|
# For JSON code blocks, detect early if we see enough structure
|
||||||
|
(has_potential_json_block and '"name"' in current_text
|
||||||
|
and '"arguments"' in current_text))
|
||||||
|
|
||||||
|
if not is_tool_call_block:
|
||||||
return DeltaMessage(content=delta_text)
|
return DeltaMessage(content=delta_text)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -204,7 +224,10 @@ class xLAMToolParser(ToolParser):
|
|||||||
|
|
||||||
# Try parsing as JSON to check for complete tool calls
|
# Try parsing as JSON to check for complete tool calls
|
||||||
try:
|
try:
|
||||||
parsed_tools = json.loads(current_text)
|
# Use preprocessed tool calls if available
|
||||||
|
tool_calls_text = (preprocessed_tool_calls if
|
||||||
|
preprocessed_tool_calls else current_text)
|
||||||
|
parsed_tools = json.loads(tool_calls_text)
|
||||||
if isinstance(parsed_tools, list):
|
if isinstance(parsed_tools, list):
|
||||||
# Update our tool array for next time
|
# Update our tool array for next time
|
||||||
self.prev_tool_call_arr = parsed_tools
|
self.prev_tool_call_arr = parsed_tools
|
||||||
@ -257,13 +280,40 @@ class xLAMToolParser(ToolParser):
|
|||||||
return delta
|
return delta
|
||||||
|
|
||||||
# Use regex to identify tool calls in the output
|
# Use regex to identify tool calls in the output
|
||||||
|
# Use preprocessed tool calls text for better parsing, but also try to extract from incomplete JSON blocks
|
||||||
|
search_text = (preprocessed_tool_calls
|
||||||
|
if preprocessed_tool_calls else current_text)
|
||||||
|
|
||||||
|
# For JSON code blocks that aren't complete yet, try to extract the JSON content
|
||||||
|
if not preprocessed_tool_calls and has_potential_json_block:
|
||||||
|
# Try to extract the JSON array from within the code block
|
||||||
|
json_match = re.search(r"```(?:json)?\s*([\s\S]*?)(?:```|$)",
|
||||||
|
current_text)
|
||||||
|
if json_match:
|
||||||
|
potential_json = json_match.group(1).strip()
|
||||||
|
# Use this as search text even if it's incomplete
|
||||||
|
if potential_json.startswith("[") and (
|
||||||
|
'"name"' in potential_json
|
||||||
|
and '"arguments"' in potential_json):
|
||||||
|
search_text = potential_json
|
||||||
|
|
||||||
|
# Try to find complete tool names first
|
||||||
name_pattern = r'"name"\s*:\s*"([^"]+)"'
|
name_pattern = r'"name"\s*:\s*"([^"]+)"'
|
||||||
name_matches = list(re.finditer(name_pattern, current_text))
|
name_matches = list(re.finditer(name_pattern, search_text))
|
||||||
tool_count = len(name_matches)
|
tool_count = len(name_matches)
|
||||||
|
|
||||||
# If no tools found yet, return
|
# If no complete tool names found, check for partial tool names
|
||||||
if tool_count == 0:
|
if tool_count == 0:
|
||||||
return None
|
# Check if we're in the middle of parsing a tool name
|
||||||
|
partial_name_pattern = r'"name"\s*:\s*"([^"]*)'
|
||||||
|
partial_matches = list(
|
||||||
|
re.finditer(partial_name_pattern, search_text))
|
||||||
|
if partial_matches:
|
||||||
|
# We have a partial tool name - not ready to emit yet
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
# No tools found at all
|
||||||
|
return None
|
||||||
|
|
||||||
# Ensure our state arrays are large enough
|
# Ensure our state arrays are large enough
|
||||||
while len(self.streaming_state["sent_tools"]) < tool_count:
|
while len(self.streaming_state["sent_tools"]) < tool_count:
|
||||||
@ -332,7 +382,7 @@ class xLAMToolParser(ToolParser):
|
|||||||
# First, check for the empty arguments case: "arguments": {}
|
# First, check for the empty arguments case: "arguments": {}
|
||||||
empty_args_pattern = (
|
empty_args_pattern = (
|
||||||
r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{\s*\}')
|
r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{\s*\}')
|
||||||
empty_args_match = re.search(empty_args_pattern, current_text)
|
empty_args_match = re.search(empty_args_pattern, search_text)
|
||||||
|
|
||||||
# Check if this tool has empty arguments
|
# Check if this tool has empty arguments
|
||||||
if empty_args_match and empty_args_match.start() > 0:
|
if empty_args_match and empty_args_match.start() > 0:
|
||||||
@ -376,7 +426,7 @@ class xLAMToolParser(ToolParser):
|
|||||||
|
|
||||||
# Extract arguments for current tool using regex for non-empty arguments
|
# Extract arguments for current tool using regex for non-empty arguments
|
||||||
args_pattern = r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*(\{(?:[^{}]|(?:\{[^{}]*\}))*\})'
|
args_pattern = r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*(\{(?:[^{}]|(?:\{[^{}]*\}))*\})'
|
||||||
args_matches = list(re.finditer(args_pattern, current_text))
|
args_matches = list(re.finditer(args_pattern, search_text))
|
||||||
|
|
||||||
if current_idx < len(args_matches):
|
if current_idx < len(args_matches):
|
||||||
args_text = args_matches[current_idx].group(1)
|
args_text = args_matches[current_idx].group(1)
|
||||||
@ -384,17 +434,25 @@ class xLAMToolParser(ToolParser):
|
|||||||
# Handle transition between tools
|
# Handle transition between tools
|
||||||
is_last_tool = current_idx == tool_count - 1
|
is_last_tool = current_idx == tool_count - 1
|
||||||
|
|
||||||
# Find where the arguments for our current tool end
|
# For multiple tools, extract only the arguments for the current tool
|
||||||
if not is_last_tool:
|
if tool_count > 1:
|
||||||
# If we have more tools after this one, try to find the complete argument block
|
# Parse the entire JSON structure to properly extract arguments for each tool
|
||||||
next_tool_pos = current_text.find(
|
try:
|
||||||
"},{", args_matches[current_idx].start())
|
parsed_tools = json.loads(search_text)
|
||||||
if next_tool_pos != -1:
|
if isinstance(
|
||||||
args_end_pos = (next_tool_pos + 1
|
parsed_tools,
|
||||||
) # +1 to include the '}'
|
list) and current_idx < len(parsed_tools):
|
||||||
args_text = (current_text[args_matches[current_idx]
|
current_tool = parsed_tools[current_idx]
|
||||||
.start():args_end_pos].
|
if isinstance(current_tool.get("arguments"),
|
||||||
split('"arguments":')[1].strip())
|
dict):
|
||||||
|
args_text = json.dumps(
|
||||||
|
current_tool["arguments"])
|
||||||
|
else:
|
||||||
|
args_text = str(
|
||||||
|
current_tool.get("arguments", "{}"))
|
||||||
|
except (json.JSONDecodeError, KeyError, IndexError):
|
||||||
|
# Fallback to regex-based extraction
|
||||||
|
pass
|
||||||
|
|
||||||
# If arguments haven't been sent yet
|
# If arguments haven't been sent yet
|
||||||
sent_args = self.streaming_state["sent_tools"][
|
sent_args = self.streaming_state["sent_tools"][
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user