mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-06 11:02:16 +08:00
[Frontend] Added support for HermesToolParser for models without special tokens (#16890)
Signed-off-by: minpeter <kali2005611@gmail.com>
This commit is contained in:
parent
52ce1420e9
commit
68373d3126
127
tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py
Normal file
127
tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from ....utils import RemoteOpenAIServer
|
||||||
|
|
||||||
|
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
|
||||||
|
LORA_MODEL = "minpeter/LoRA-Llama-3.2-1B-tool-vllm-ci"
|
||||||
|
|
||||||
|
SERVER_ARGS = [
|
||||||
|
"--enforce-eager",
|
||||||
|
"--enable-auto-tool-choice",
|
||||||
|
"--tool-call-parser",
|
||||||
|
"hermes",
|
||||||
|
"--enable-lora",
|
||||||
|
"--lora-modules",
|
||||||
|
f"{LORA_MODEL}={LORA_MODEL}",
|
||||||
|
]
|
||||||
|
|
||||||
|
TOOLS = [{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_current_weather",
|
||||||
|
"description": "Get the current weather in a given location",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description":
|
||||||
|
"The city and state, e.g. San Francisco, CA",
|
||||||
|
},
|
||||||
|
"unit": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["celsius", "fahrenheit"]
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["location"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}]
|
||||||
|
|
||||||
|
MESSAGES = [{"role": "user", "content": "What's the weather like in Boston?"}]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_non_streaming_tool_call():
|
||||||
|
"""Test tool call in non-streaming mode."""
|
||||||
|
with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server:
|
||||||
|
client = server.get_async_client()
|
||||||
|
|
||||||
|
response = await client.chat.completions.create(
|
||||||
|
model=LORA_MODEL,
|
||||||
|
messages=MESSAGES,
|
||||||
|
tools=TOOLS,
|
||||||
|
tool_choice="auto",
|
||||||
|
temperature=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.choices
|
||||||
|
choice = response.choices[0]
|
||||||
|
message = choice.message
|
||||||
|
|
||||||
|
assert choice.finish_reason == "tool_calls"
|
||||||
|
assert message.tool_calls is not None
|
||||||
|
|
||||||
|
tool_call = message.tool_calls[0]
|
||||||
|
assert tool_call.type == "function"
|
||||||
|
assert tool_call.function.name == "get_current_weather"
|
||||||
|
|
||||||
|
arguments = json.loads(tool_call.function.arguments)
|
||||||
|
assert "location" in arguments
|
||||||
|
assert "Boston" in arguments["location"]
|
||||||
|
print("\n[Non-Streaming Test Passed]")
|
||||||
|
print(f"Tool Call: {tool_call.function.name}")
|
||||||
|
print(f"Arguments: {arguments}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_streaming_tool_call():
|
||||||
|
"""Test tool call in streaming mode."""
|
||||||
|
with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server:
|
||||||
|
client = server.get_async_client()
|
||||||
|
|
||||||
|
stream = await client.chat.completions.create(
|
||||||
|
model=LORA_MODEL,
|
||||||
|
messages=MESSAGES,
|
||||||
|
tools=TOOLS,
|
||||||
|
tool_choice="auto",
|
||||||
|
temperature=0.0,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_call_chunks = {}
|
||||||
|
async for chunk in stream:
|
||||||
|
if not chunk.choices:
|
||||||
|
continue
|
||||||
|
|
||||||
|
delta = chunk.choices[0].delta
|
||||||
|
if not delta or not delta.tool_calls:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for tool_chunk in delta.tool_calls:
|
||||||
|
index = tool_chunk.index
|
||||||
|
if index not in tool_call_chunks:
|
||||||
|
tool_call_chunks[index] = {"name": "", "arguments": ""}
|
||||||
|
|
||||||
|
if tool_chunk.function.name:
|
||||||
|
tool_call_chunks[index]["name"] += tool_chunk.function.name
|
||||||
|
if tool_chunk.function.arguments:
|
||||||
|
tool_call_chunks[index][
|
||||||
|
"arguments"] += tool_chunk.function.arguments
|
||||||
|
|
||||||
|
assert len(tool_call_chunks) == 1
|
||||||
|
reconstructed_tool_call = tool_call_chunks[0]
|
||||||
|
|
||||||
|
assert reconstructed_tool_call["name"] == "get_current_weather"
|
||||||
|
|
||||||
|
arguments = json.loads(reconstructed_tool_call["arguments"])
|
||||||
|
assert "location" in arguments
|
||||||
|
assert "Boston" in arguments["location"]
|
||||||
|
print("\n[Streaming Test Passed]")
|
||||||
|
print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}")
|
||||||
|
print(f"Reconstructed Arguments: {arguments}")
|
||||||
@ -52,14 +52,51 @@ class Hermes2ProToolParser(ToolParser):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The model tokenizer must be passed to the ToolParser "
|
"The model tokenizer must be passed to the ToolParser "
|
||||||
"constructor during construction.")
|
"constructor during construction.")
|
||||||
self.tool_call_start_token_id = self.vocab.get(
|
self.tool_call_start_token_ids = self.model_tokenizer.encode(
|
||||||
self.tool_call_start_token)
|
self.tool_call_start_token, add_special_tokens=False)
|
||||||
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
|
self.tool_call_end_token_ids = self.model_tokenizer.encode(
|
||||||
if (self.tool_call_start_token_id is None
|
self.tool_call_end_token, add_special_tokens=False)
|
||||||
or self.tool_call_end_token_id is None):
|
|
||||||
raise RuntimeError(
|
self.tool_call_start_token_array = [
|
||||||
"Hermes 2 Pro Tool parser could not locate tool call start/end "
|
self.model_tokenizer.decode([token_id])
|
||||||
"tokens in the tokenizer!")
|
for token_id in self.tool_call_start_token_ids
|
||||||
|
]
|
||||||
|
|
||||||
|
self.tool_call_end_token_array = [
|
||||||
|
self.model_tokenizer.decode([token_id])
|
||||||
|
for token_id in self.tool_call_end_token_ids
|
||||||
|
]
|
||||||
|
|
||||||
|
self.buffered_delta_text = ""
|
||||||
|
|
||||||
|
# Very simple idea: when encountering tokens like <, tool, _call, >,
|
||||||
|
# <, /, tool, _call, >, store them in a buffer.
|
||||||
|
# When the last token is encountered, empty the buffer and return it.
|
||||||
|
# If a token appears in an incorrect sequence while storing in the buffer,
|
||||||
|
# return the preceding buffer along with the token.
|
||||||
|
def tool_call_delta_buffer(self, delta_text: str):
|
||||||
|
# If the sequence of tool_call_start or tool_call_end tokens is not yet
|
||||||
|
# complete, fill the buffer with the token and return "".
|
||||||
|
if (delta_text in self.tool_call_start_token_array
|
||||||
|
or delta_text in self.tool_call_end_token_array):
|
||||||
|
# If delta_text is the last token of tool_call_start_token or
|
||||||
|
# tool_call_end_token, empty the buffer and return
|
||||||
|
# the buffered text + delta_text.
|
||||||
|
if (delta_text == self.tool_call_start_token_array[-1]
|
||||||
|
or delta_text == self.tool_call_end_token_array[-1]):
|
||||||
|
buffered_text = self.buffered_delta_text
|
||||||
|
self.buffered_delta_text = ""
|
||||||
|
return buffered_text + delta_text
|
||||||
|
else:
|
||||||
|
self.buffered_delta_text = self.buffered_delta_text + delta_text
|
||||||
|
return ""
|
||||||
|
else:
|
||||||
|
if self.buffered_delta_text:
|
||||||
|
buffered_text = self.buffered_delta_text
|
||||||
|
self.buffered_delta_text = ""
|
||||||
|
return buffered_text + delta_text
|
||||||
|
else:
|
||||||
|
return delta_text
|
||||||
|
|
||||||
def extract_tool_calls(
|
def extract_tool_calls(
|
||||||
self,
|
self,
|
||||||
@ -124,11 +161,23 @@ class Hermes2ProToolParser(ToolParser):
|
|||||||
delta_token_ids: Sequence[int],
|
delta_token_ids: Sequence[int],
|
||||||
request: ChatCompletionRequest,
|
request: ChatCompletionRequest,
|
||||||
) -> Union[DeltaMessage, None]:
|
) -> Union[DeltaMessage, None]:
|
||||||
|
# 1. All tokens are parsed based on _text, not token_ids.
|
||||||
|
# 2. All incoming text data is processed by the tool_call_delta_buffer
|
||||||
|
# function for buffering before being used for parsing.
|
||||||
|
|
||||||
|
delta_text = self.tool_call_delta_buffer(delta_text)
|
||||||
|
# If the last characters of previous_text
|
||||||
|
# match self.buffered_delta_text, remove only the matching part.
|
||||||
|
if (len(previous_text) >= len(self.buffered_delta_text)
|
||||||
|
and previous_text[-len(self.buffered_delta_text):]
|
||||||
|
== self.buffered_delta_text):
|
||||||
|
previous_text = previous_text[:-len(self.buffered_delta_text)]
|
||||||
|
current_text = previous_text + delta_text
|
||||||
|
|
||||||
logger.debug("delta_text: %s", delta_text)
|
logger.debug("delta_text: %s", delta_text)
|
||||||
logger.debug("delta_token_ids: %s", delta_token_ids)
|
logger.debug("delta_token_ids: %s", delta_token_ids)
|
||||||
# check to see if we should be streaming a tool call - is there a
|
# check to see if we should be streaming a tool call - is there a
|
||||||
if self.tool_call_start_token_id not in current_token_ids:
|
if self.tool_call_start_token not in current_text:
|
||||||
logger.debug("No tool call tokens found!")
|
logger.debug("No tool call tokens found!")
|
||||||
return DeltaMessage(content=delta_text)
|
return DeltaMessage(content=delta_text)
|
||||||
|
|
||||||
@ -136,14 +185,12 @@ class Hermes2ProToolParser(ToolParser):
|
|||||||
|
|
||||||
# figure out where we are in the parsing by counting tool call
|
# figure out where we are in the parsing by counting tool call
|
||||||
# start & end tags
|
# start & end tags
|
||||||
prev_tool_start_count = previous_token_ids.count(
|
prev_tool_start_count = previous_text.count(
|
||||||
self.tool_call_start_token_id)
|
self.tool_call_start_token)
|
||||||
prev_tool_end_count = previous_token_ids.count(
|
prev_tool_end_count = previous_text.count(self.tool_call_end_token)
|
||||||
self.tool_call_end_token_id)
|
cur_tool_start_count = current_text.count(
|
||||||
cur_tool_start_count = current_token_ids.count(
|
self.tool_call_start_token)
|
||||||
self.tool_call_start_token_id)
|
cur_tool_end_count = current_text.count(self.tool_call_end_token)
|
||||||
cur_tool_end_count = current_token_ids.count(
|
|
||||||
self.tool_call_end_token_id)
|
|
||||||
tool_call_portion = None
|
tool_call_portion = None
|
||||||
text_portion = None
|
text_portion = None
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user