mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-24 03:24:28 +08:00
[Frontend][Feature] Add jamba tool parser (#9154)
This commit is contained in:
parent
1ffc8a7362
commit
d2b1bf55ec
@ -157,7 +157,7 @@ vLLM will use guided decoding to ensure the response matches the tool parameter
|
||||
To enable this feature, you should set the following flags:
|
||||
* `--enable-auto-tool-choice` -- **mandatory** Auto tool choice. tells vLLM that you want to enable the model to generate its own tool calls when it
|
||||
deems appropriate.
|
||||
* `--tool-call-parser` -- select the tool parser to use - currently either `hermes` or `mistral` or `llama3_json` or `internlm`. Additional tool parsers
|
||||
* `--tool-call-parser` -- select the tool parser to use (listed below). Additional tool parsers
|
||||
will continue to be added in the future, and also can register your own tool parsers in the `--tool-parser-plugin`.
|
||||
* `--tool-parser-plugin` -- **optional** tool parser plugin used to register user defined tool parsers into vllm, the registered tool parser name can be specified in `--tool-call-parser`.
|
||||
* `--chat-template` -- **optional** for auto tool choice. the path to the chat template which handles `tool`-role messages and `assistant`-role messages
|
||||
@ -168,7 +168,7 @@ from HuggingFace; and you can find an example of this in a `tokenizer_config.jso
|
||||
|
||||
If your favorite tool-calling model is not supported, please feel free to contribute a parser & tool use chat template!
|
||||
|
||||
#### Hermes Models
|
||||
#### Hermes Models (`hermes`)
|
||||
All Nous Research Hermes-series models newer than Hermes 2 Pro should be supported.
|
||||
* `NousResearch/Hermes-2-Pro-*`
|
||||
* `NousResearch/Hermes-2-Theta-*`
|
||||
@ -180,7 +180,7 @@ step in their creation_.
|
||||
|
||||
Flags: `--tool-call-parser hermes`
|
||||
|
||||
#### Mistral Models
|
||||
#### Mistral Models (`mistral`)
|
||||
Supported models:
|
||||
* `mistralai/Mistral-7B-Instruct-v0.3` (confirmed)
|
||||
* Additional mistral function-calling models are compatible as well.
|
||||
@ -199,7 +199,7 @@ when tools are provided, that results in much better reliability when working wi
|
||||
|
||||
Recommended flags: `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja`
|
||||
|
||||
#### Llama Models
|
||||
#### Llama Models (`llama3_json`)
|
||||
Supported models:
|
||||
* `meta-llama/Meta-Llama-3.1-8B-Instruct`
|
||||
* `meta-llama/Meta-Llama-3.1-70B-Instruct`
|
||||
@ -219,16 +219,24 @@ it works better with vLLM.
|
||||
|
||||
Recommended flags: `--tool-call-parser llama3_json --chat-template examples/tool_chat_template_llama3_json.jinja`
|
||||
|
||||
#### Internlm Models
|
||||
#### InternLM Models (`internlm`)
|
||||
Supported models:
|
||||
* `internlm/internlm2_5-7b-chat` (confirmed)
|
||||
* Additional internlm2.5 function-calling models are compatible as well
|
||||
|
||||
Known issues:
|
||||
* Although this implementation also supports Internlm2, the tool call results are not stable when testing with the `internlm/internlm2-chat-7b` model.
|
||||
* Although this implementation also supports InternLM2, the tool call results are not stable when testing with the `internlm/internlm2-chat-7b` model.
|
||||
|
||||
Recommended flags: `--tool-call-parser internlm --chat-template examples/tool_chat_template_internlm2_tool.jinja`
|
||||
|
||||
#### Jamba Models (`jamba`)
|
||||
AI21's Jamba-1.5 models are supported.
|
||||
* `ai21labs/AI21-Jamba-1.5-Mini`
|
||||
* `ai21labs/AI21-Jamba-1.5-Large`
|
||||
|
||||
|
||||
Flags: `--tool-call-parser jamba`
|
||||
|
||||
|
||||
### How to write a tool parser plugin
|
||||
|
||||
|
||||
275
tests/tool_use/test_jamba_tool_parser.py
Normal file
275
tests/tool_use/test_jamba_tool_parser.py
Normal file
@ -0,0 +1,275 @@
|
||||
import json
|
||||
from typing import Generator, List, Optional
|
||||
|
||||
import partial_json_parser
|
||||
import pytest
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (DeltaMessage, FunctionCall,
|
||||
ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers import JambaToolParser
|
||||
from vllm.transformers_utils.detokenizer import detokenize_incrementally
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
|
||||
|
||||
MODEL = "ai21labs/Jamba-tiny-dev"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def jamba_tokenizer():
|
||||
return get_tokenizer(tokenizer_name=MODEL)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def jamba_tool_parser(jamba_tokenizer):
|
||||
return JambaToolParser(jamba_tokenizer)
|
||||
|
||||
|
||||
def assert_tool_calls(actual_tool_calls: List[ToolCall],
|
||||
expected_tool_calls: List[ToolCall]):
|
||||
assert len(actual_tool_calls) == len(expected_tool_calls)
|
||||
|
||||
for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
|
||||
expected_tool_calls):
|
||||
assert isinstance(actual_tool_call.id, str)
|
||||
assert len(actual_tool_call.id) > 16
|
||||
|
||||
assert actual_tool_call.type == "function"
|
||||
assert actual_tool_call.function == expected_tool_call.function
|
||||
|
||||
|
||||
def stream_delta_message_generator(
|
||||
jamba_tool_parser: JambaToolParser, jamba_tokenizer: AnyTokenizer,
|
||||
model_output: str) -> Generator[DeltaMessage, None, None]:
|
||||
all_token_ids = jamba_tokenizer.encode(model_output,
|
||||
add_special_tokens=False)
|
||||
|
||||
previous_text = ""
|
||||
previous_tokens = None
|
||||
prefix_offset = 0
|
||||
read_offset = 0
|
||||
for i, delta_token in enumerate(all_token_ids):
|
||||
delta_token_ids = [delta_token]
|
||||
previous_token_ids = all_token_ids[:i]
|
||||
current_token_ids = all_token_ids[:i + 1]
|
||||
|
||||
(new_tokens, delta_text, new_prefix_offset,
|
||||
new_read_offset) = detokenize_incrementally(
|
||||
tokenizer=jamba_tokenizer,
|
||||
all_input_ids=current_token_ids,
|
||||
prev_tokens=previous_tokens,
|
||||
prefix_offset=prefix_offset,
|
||||
read_offset=read_offset,
|
||||
skip_special_tokens=False,
|
||||
spaces_between_special_tokens=True,
|
||||
)
|
||||
|
||||
current_text = previous_text + delta_text
|
||||
|
||||
delta_message = jamba_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text,
|
||||
current_text,
|
||||
delta_text,
|
||||
previous_token_ids,
|
||||
current_token_ids,
|
||||
delta_token_ids,
|
||||
request=None, # type: ignore[arg-type]
|
||||
)
|
||||
if delta_message:
|
||||
yield delta_message
|
||||
|
||||
previous_text = current_text
|
||||
previous_tokens = previous_tokens + new_tokens if previous_tokens\
|
||||
else new_tokens
|
||||
prefix_offset = new_prefix_offset
|
||||
read_offset = new_read_offset
|
||||
|
||||
|
||||
def test_extract_tool_calls_no_tools(jamba_tool_parser):
|
||||
model_output = "This is a test"
|
||||
extracted_tool_calls = jamba_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
assert not extracted_tool_calls.tools_called
|
||||
assert extracted_tool_calls.tool_calls == []
|
||||
assert extracted_tool_calls.content == model_output
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
ids=[
|
||||
"single_tool",
|
||||
"single_tool_with_content",
|
||||
"parallel_tools",
|
||||
],
|
||||
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||
argvalues=[
|
||||
(
|
||||
''' <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501
|
||||
[
|
||||
ToolCall(function=FunctionCall(name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit"
|
||||
})))
|
||||
],
|
||||
None),
|
||||
(
|
||||
''' Sure! let me call the tool for you.<tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501
|
||||
[
|
||||
ToolCall(function=FunctionCall(name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit"
|
||||
})))
|
||||
],
|
||||
" Sure! let me call the tool for you."),
|
||||
(
|
||||
''' <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}},\n {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501
|
||||
[
|
||||
ToolCall(function=FunctionCall(name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit"
|
||||
}))),
|
||||
ToolCall(function=FunctionCall(name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Orlando",
|
||||
"state": "FL",
|
||||
"unit": "fahrenheit"
|
||||
})))
|
||||
],
|
||||
None)
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls(jamba_tool_parser, model_output,
|
||||
expected_tool_calls, expected_content):
|
||||
extracted_tool_calls = jamba_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
assert extracted_tool_calls.tools_called
|
||||
|
||||
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
|
||||
|
||||
assert extracted_tool_calls.content == expected_content
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
ids=[
|
||||
"no_tools",
|
||||
"single_tool",
|
||||
"single_tool_with_content",
|
||||
"parallel_tools",
|
||||
],
|
||||
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||
argvalues=[
|
||||
('''This is a test''', [], '''This is a test'''),
|
||||
(
|
||||
''' <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501
|
||||
[
|
||||
ToolCall(function=FunctionCall(name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit"
|
||||
})))
|
||||
],
|
||||
" "),
|
||||
(
|
||||
''' Sure! let me call the tool for you.<tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501
|
||||
[
|
||||
ToolCall(function=FunctionCall(name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit"
|
||||
})))
|
||||
],
|
||||
" Sure! let me call the tool for you."),
|
||||
(
|
||||
''' <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}},\n {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501
|
||||
[
|
||||
ToolCall(function=FunctionCall(name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit"
|
||||
}))),
|
||||
ToolCall(function=FunctionCall(name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Orlando",
|
||||
"state": "FL",
|
||||
"unit": "fahrenheit"
|
||||
})))
|
||||
],
|
||||
" ")
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls_streaming(jamba_tool_parser, jamba_tokenizer,
|
||||
model_output, expected_tool_calls,
|
||||
expected_content):
|
||||
other_content: str = ''
|
||||
function_names: List[str] = []
|
||||
function_args_strs: List[str] = []
|
||||
tool_call_idx: int = -1
|
||||
tool_call_ids: List[Optional[str]] = []
|
||||
|
||||
for delta_message in stream_delta_message_generator(
|
||||
jamba_tool_parser, jamba_tokenizer, model_output):
|
||||
# role should never be streamed from tool parser
|
||||
assert not delta_message.role
|
||||
|
||||
if delta_message.content:
|
||||
other_content += delta_message.content
|
||||
|
||||
streamed_tool_calls = delta_message.tool_calls
|
||||
|
||||
if streamed_tool_calls and len(streamed_tool_calls) > 0:
|
||||
# make sure only one diff is present - correct even for parallel
|
||||
assert len(streamed_tool_calls) == 1
|
||||
tool_call = streamed_tool_calls[0]
|
||||
|
||||
# if a new tool is being called, set up empty arguments
|
||||
if tool_call.index != tool_call_idx:
|
||||
tool_call_idx = tool_call.index
|
||||
function_args_strs.append("")
|
||||
tool_call_ids.append(None)
|
||||
|
||||
# if a tool call ID is streamed, make sure one hasn't been already
|
||||
if tool_call.id and not tool_call_ids[tool_call.index]:
|
||||
tool_call_ids[tool_call.index] = tool_call.id
|
||||
|
||||
# if parts of the function start being streamed
|
||||
if tool_call.function:
|
||||
# if the function name is defined, set it. it should be streamed
|
||||
# IN ENTIRETY, exactly one time.
|
||||
if tool_call.function.name:
|
||||
assert isinstance(tool_call.function.name, str)
|
||||
function_names.append(tool_call.function.name)
|
||||
|
||||
if tool_call.function.arguments:
|
||||
# make sure they're a string and then add them to the list
|
||||
assert isinstance(tool_call.function.arguments, str)
|
||||
|
||||
function_args_strs[
|
||||
tool_call.index] += tool_call.function.arguments
|
||||
|
||||
assert other_content == expected_content
|
||||
|
||||
actual_tool_calls = [
|
||||
ToolCall(id=tool_call_id,
|
||||
function=FunctionCall(
|
||||
name=function_name,
|
||||
arguments=partial_json_parser.ensure_json(
|
||||
function_args_str, Allow.OBJ | Allow.STR)))
|
||||
for tool_call_id, function_name, function_args_str in zip(
|
||||
tool_call_ids, function_names, function_args_strs)
|
||||
]
|
||||
assert_tool_calls(actual_tool_calls, expected_tool_calls)
|
||||
@ -1,10 +1,12 @@
|
||||
from .abstract_tool_parser import ToolParser, ToolParserManager
|
||||
from .hermes_tool_parser import Hermes2ProToolParser
|
||||
from .internlm2_tool_parser import Internlm2ToolParser
|
||||
from .jamba_tool_parser import JambaToolParser
|
||||
from .llama_tool_parser import Llama3JsonToolParser
|
||||
from .mistral_tool_parser import MistralToolParser
|
||||
|
||||
__all__ = [
|
||||
"ToolParser", "ToolParserManager", "Hermes2ProToolParser",
|
||||
"MistralToolParser", "Internlm2ToolParser", "Llama3JsonToolParser"
|
||||
"MistralToolParser", "Internlm2ToolParser", "Llama3JsonToolParser",
|
||||
"JambaToolParser"
|
||||
]
|
||||
|
||||
@ -53,7 +53,8 @@ class Hermes2ProToolParser(ToolParser):
|
||||
self.tool_call_start_token_id = self.vocab.get(
|
||||
self.tool_call_start_token)
|
||||
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
|
||||
if not self.tool_call_start_token_id or not self.tool_call_end_token_id:
|
||||
if (self.tool_call_start_token_id is None
|
||||
or self.tool_call_end_token_id is None):
|
||||
raise RuntimeError(
|
||||
"Hermes 2 Pro Tool parser could not locate tool call start/end "
|
||||
"tokens in the tokenizer!")
|
||||
|
||||
300
vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py
Normal file
300
vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py
Normal file
@ -0,0 +1,300 @@
|
||||
import json
|
||||
import re
|
||||
from typing import Dict, List, Sequence, Union
|
||||
|
||||
import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import (
|
||||
extract_intermediate_diff)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.transformers_utils.tokenizers import MistralTokenizer
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module("jamba")
|
||||
class JambaToolParser(ToolParser):
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
if isinstance(self.model_tokenizer, MistralTokenizer):
|
||||
raise ValueError(
|
||||
"Detected a MistralTokenizer tokenizer when using a Jamba model"
|
||||
)
|
||||
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.prev_tool_call_arr: List[Dict] = []
|
||||
self.current_tool_id: int = -1
|
||||
self.streamed_args_for_tool: List[str] = [
|
||||
] # map what has been streamed for each tool so far to a list
|
||||
|
||||
self.tool_calls_start_token: str = "<tool_calls>"
|
||||
self.tool_calls_end_token: str = "</tool_calls>"
|
||||
|
||||
self.tool_calls_regex = re.compile(
|
||||
rf"{self.tool_calls_start_token}(.*?){self.tool_calls_end_token}",
|
||||
re.DOTALL)
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ToolParser "
|
||||
"constructor during construction.")
|
||||
self.tool_calls_start_token_id = self.vocab.get(
|
||||
self.tool_calls_start_token)
|
||||
self.tool_calls_end_token_id = self.vocab.get(
|
||||
self.tool_calls_end_token)
|
||||
if (self.tool_calls_start_token_id is None
|
||||
or self.tool_calls_end_token_id is None):
|
||||
raise RuntimeError(
|
||||
"Jamba Tool parser could not locate tool calls start/end "
|
||||
"tokens in the tokenizer!")
|
||||
|
||||
def adjust_request(
|
||||
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
|
||||
if request.tools and request.tool_choice != 'none':
|
||||
# do not skip special tokens because jamba use the special
|
||||
# tokens to indicate the start and end of the tool calls
|
||||
# information.
|
||||
request.skip_special_tokens = False
|
||||
return request
|
||||
|
||||
def extract_tool_calls(
|
||||
self, model_output: str,
|
||||
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
|
||||
|
||||
# sanity check; avoid unnecessary processing
|
||||
if self.tool_calls_start_token not in model_output:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
else:
|
||||
|
||||
try:
|
||||
# use a regex to find the tool call between the tags
|
||||
function_calls = self.tool_calls_regex.findall(model_output)[0]
|
||||
|
||||
# load the JSON, and then use it to build the Function and
|
||||
# Tool Call
|
||||
raw_function_calls = json.loads(function_calls)
|
||||
tool_calls = [
|
||||
ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=function_call["name"],
|
||||
# function call args are JSON but as a string
|
||||
arguments=json.dumps(function_call["arguments"])))
|
||||
for function_call in raw_function_calls
|
||||
]
|
||||
|
||||
content = model_output[:model_output.
|
||||
find(self.tool_calls_start_token)]
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=content if
|
||||
(len(content) > 0 and content != " ") else None)
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Error in extracting tool call from response.")
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
# if the tool call token is not in the tokens generated so far, append
|
||||
# output to contents since it's not a tool
|
||||
if self.tool_calls_start_token not in current_text:
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
# if the tool call token ID IS in the tokens generated so far, that
|
||||
# means we're parsing as tool calls now
|
||||
|
||||
# handle if we detected the start of tool calls token which means
|
||||
# the start of tool calling
|
||||
if (self.tool_calls_start_token_id in delta_token_ids
|
||||
and len(delta_token_ids) == 1):
|
||||
# if it's the only token, return None, so we don't send a chat
|
||||
# completion and don't send a control token
|
||||
return None
|
||||
|
||||
# bit mask flags for partial JSON parsing. If the name hasn't been
|
||||
# sent yet, don't allow sending
|
||||
# an incomplete string since OpenAI only ever (as far as I have
|
||||
# seen) allows sending the entire tool/ function name at once.
|
||||
flags = Allow.ALL if self.current_tool_name_sent \
|
||||
else Allow.ALL & ~Allow.STR
|
||||
try:
|
||||
|
||||
# Extract the tool calls between the special tool call tokens
|
||||
parsable_arr = current_text.split(
|
||||
self.tool_calls_start_token)[-1].split(
|
||||
self.tool_calls_end_token)[0]
|
||||
|
||||
# tool calls are generated in an array, so do partial JSON
|
||||
# parsing on the entire array
|
||||
try:
|
||||
tool_call_arr: List[Dict] = partial_json_parser.loads(
|
||||
parsable_arr, flags)
|
||||
except partial_json_parser.core.exceptions.MalformedJSON:
|
||||
logger.debug('not enough tokens to parse into JSON yet')
|
||||
return None
|
||||
|
||||
# select as the current tool call the one we're on the state at
|
||||
|
||||
current_tool_call: Dict = tool_call_arr[self.current_tool_id] \
|
||||
if len(tool_call_arr) > 0 else {}
|
||||
|
||||
# case -- if no tokens have been streamed for the tool, e.g.
|
||||
# only the array brackets, stream nothing
|
||||
if len(tool_call_arr) == 0:
|
||||
return None
|
||||
|
||||
# case: we are starting a new tool in the array
|
||||
# -> array has > 0 length AND length has moved past cursor
|
||||
elif (len(tool_call_arr) > 0
|
||||
and len(tool_call_arr) > self.current_tool_id + 1):
|
||||
|
||||
# if we're moving on to a new call, first make sure we
|
||||
# haven't missed anything in the previous one that was
|
||||
# auto-generated due to JSON completions, but wasn't
|
||||
# streamed to the client yet.
|
||||
if self.current_tool_id >= 0:
|
||||
diff: Union[str, None] = current_tool_call.get("arguments")
|
||||
|
||||
if diff:
|
||||
diff = json.dumps(diff).replace(
|
||||
self.streamed_args_for_tool[self.current_tool_id],
|
||||
"")
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=diff).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += diff
|
||||
else:
|
||||
delta = None
|
||||
else:
|
||||
delta = None
|
||||
# re-set stuff pertaining to progress in the current tool
|
||||
self.current_tool_id = len(tool_call_arr) - 1
|
||||
self.current_tool_name_sent = False
|
||||
self.streamed_args_for_tool.append("")
|
||||
logger.debug("starting on new tool %d", self.current_tool_id)
|
||||
return delta
|
||||
|
||||
# case: update an existing tool - this is handled below
|
||||
|
||||
# if the current tool name hasn't been sent, send if available
|
||||
# - otherwise send nothing
|
||||
if not self.current_tool_name_sent:
|
||||
function_name = current_tool_call.get("name")
|
||||
if function_name:
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
type="function",
|
||||
id=f"chatcmpl-tool-{random_uuid()}",
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.current_tool_name_sent = True
|
||||
else:
|
||||
delta = None
|
||||
|
||||
# now we know we're on the same tool call and we're streaming
|
||||
# arguments
|
||||
else:
|
||||
|
||||
prev_arguments = self.prev_tool_call_arr[
|
||||
self.current_tool_id].get("arguments")
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
|
||||
new_text = delta_text.replace("\'", "\"")
|
||||
|
||||
if not cur_arguments and not prev_arguments:
|
||||
|
||||
delta = None
|
||||
elif not cur_arguments and prev_arguments:
|
||||
logger.error(
|
||||
"INVARIANT - impossible to have arguments reset "
|
||||
"mid-arguments")
|
||||
delta = None
|
||||
elif cur_arguments and not prev_arguments:
|
||||
cur_arguments_json = json.dumps(cur_arguments)
|
||||
logger.debug("finding %s in %s", new_text,
|
||||
cur_arguments_json)
|
||||
|
||||
arguments_delta = cur_arguments_json[:cur_arguments_json.
|
||||
index(new_text) +
|
||||
len(new_text)]
|
||||
logger.debug("First tokens in arguments received: %s",
|
||||
arguments_delta)
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=arguments_delta).
|
||||
model_dump(exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += arguments_delta
|
||||
|
||||
elif cur_arguments and prev_arguments:
|
||||
cur_args_json = json.dumps(cur_arguments)
|
||||
prev_args_json = json.dumps(prev_arguments)
|
||||
logger.debug("Searching for diff between \n%s\n%s",
|
||||
cur_args_json, prev_args_json)
|
||||
|
||||
argument_diff = extract_intermediate_diff(
|
||||
cur_args_json, prev_args_json)
|
||||
logger.debug("got arguments diff: %s", argument_diff)
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=argument_diff).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += argument_diff
|
||||
else:
|
||||
# try parsing it with regular JSON - if it works we're
|
||||
# at the end, and we need to send the difference between
|
||||
# tokens streamed so far and the valid JSON
|
||||
delta = None
|
||||
|
||||
# check to see if the name is defined and has been sent. if so,
|
||||
# stream the name - otherwise keep waiting
|
||||
# finish by setting old and returning None as base case
|
||||
self.prev_tool_call_arr = tool_call_arr
|
||||
return delta
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error trying to handle streaming tool call.")
|
||||
logger.debug(
|
||||
"Skipping chunk as a result of tool streaming extraction "
|
||||
"error")
|
||||
return None
|
||||
@ -63,7 +63,7 @@ class MistralToolParser(ToolParser):
|
||||
self.bot_token = "[TOOL_CALLS]"
|
||||
self.bot_token_id = self.vocab.get(self.bot_token)
|
||||
self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
|
||||
if not self.bot_token_id:
|
||||
if self.bot_token_id is None:
|
||||
raise RuntimeError(
|
||||
"Mistral Tool Parser could not locate the tool call token in "
|
||||
"the tokenizer!")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user