From 68373d3126b4d2c49a9983fe0696bbd48fc8aad7 Mon Sep 17 00:00:00 2001
From: Woonggi Min
Date: Sun, 17 Aug 2025 02:38:42 +0900
Subject: [PATCH] [Frontend] Added support for HermesToolParser for models
 without special tokens (#16890)

Signed-off-by: minpeter
---
 .../tool_parsers/test_hermes_tool_parser.py   | 127 ++++++++++++++++++
 .../openai/tool_parsers/hermes_tool_parser.py |  81 ++++++++---
 2 files changed, 191 insertions(+), 17 deletions(-)
 create mode 100644 tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py

diff --git a/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py
new file mode 100644
index 0000000000000..28b1f8358d80b
--- /dev/null
+++ b/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py
@@ -0,0 +1,127 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import json
+
+import pytest
+
+from ....utils import RemoteOpenAIServer
+
+MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
+LORA_MODEL = "minpeter/LoRA-Llama-3.2-1B-tool-vllm-ci"
+
+SERVER_ARGS = [
+    "--enforce-eager",
+    "--enable-auto-tool-choice",
+    "--tool-call-parser",
+    "hermes",
+    "--enable-lora",
+    "--lora-modules",
+    f"{LORA_MODEL}={LORA_MODEL}",
+]
+
+TOOLS = [{
+    "type": "function",
+    "function": {
+        "name": "get_current_weather",
+        "description": "Get the current weather in a given location",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "location": {
+                    "type": "string",
+                    "description":
+                    "The city and state, e.g. San Francisco, CA",
+                },
+                "unit": {
+                    "type": "string",
+                    "enum": ["celsius", "fahrenheit"]
+                },
+            },
+            "required": ["location"],
+        },
+    },
+}]
+
+MESSAGES = [{"role": "user", "content": "What's the weather like in Boston?"}]
+
+
+@pytest.mark.asyncio
+async def test_non_streaming_tool_call():
+    """Test tool call in non-streaming mode."""
+    with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server:
+        client = server.get_async_client()
+
+        response = await client.chat.completions.create(
+            model=LORA_MODEL,
+            messages=MESSAGES,
+            tools=TOOLS,
+            tool_choice="auto",
+            temperature=0.0,
+        )
+
+        assert response.choices
+        choice = response.choices[0]
+        message = choice.message
+
+        assert choice.finish_reason == "tool_calls"
+        assert message.tool_calls is not None
+
+        tool_call = message.tool_calls[0]
+        assert tool_call.type == "function"
+        assert tool_call.function.name == "get_current_weather"
+
+        arguments = json.loads(tool_call.function.arguments)
+        assert "location" in arguments
+        assert "Boston" in arguments["location"]
+        print("\n[Non-Streaming Test Passed]")
+        print(f"Tool Call: {tool_call.function.name}")
+        print(f"Arguments: {arguments}")
+
+
+@pytest.mark.asyncio
+async def test_streaming_tool_call():
+    """Test tool call in streaming mode."""
+    with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server:
+        client = server.get_async_client()
+
+        stream = await client.chat.completions.create(
+            model=LORA_MODEL,
+            messages=MESSAGES,
+            tools=TOOLS,
+            tool_choice="auto",
+            temperature=0.0,
+            stream=True,
+        )
+
+        tool_call_chunks = {}
+        async for chunk in stream:
+            if not chunk.choices:
+                continue
+
+            delta = chunk.choices[0].delta
+            if not delta or not delta.tool_calls:
+                continue
+
+            for tool_chunk in delta.tool_calls:
+                index = tool_chunk.index
+                if index not in tool_call_chunks:
+                    tool_call_chunks[index] = {"name": "", "arguments": ""}
+
+                if tool_chunk.function.name:
+                    tool_call_chunks[index]["name"] += tool_chunk.function.name
+                if tool_chunk.function.arguments:
+                    tool_call_chunks[index][
+                        "arguments"] += tool_chunk.function.arguments
+
+        assert len(tool_call_chunks) == 1
+        reconstructed_tool_call = tool_call_chunks[0]
+
+        assert reconstructed_tool_call["name"] == "get_current_weather"
+
+        arguments = json.loads(reconstructed_tool_call["arguments"])
+        assert "location" in arguments
+        assert "Boston" in arguments["location"]
+        print("\n[Streaming Test Passed]")
+        print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}")
+        print(f"Reconstructed Arguments: {arguments}")
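An aside on why the parser change below is needed: for models like the Llama 3.2 checkpoints used in the tests above, <tool_call> and </tool_call> are ordinary text rather than dedicated special tokens, so each tag is tokenized into several pieces and can arrive split across multiple streamed deltas. A minimal sketch of how to observe this, mirroring what the new constructor code does; the exact split is tokenizer-dependent and the printed pieces are illustrative only:

# Illustrative sketch: shows that "<tool_call>" maps to several ordinary
# token ids for a model without a dedicated special token. The split shown
# in the comment is an assumption; real tokenizers may segment differently.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
ids = tok.encode("<tool_call>", add_special_tokens=False)
pieces = [tok.decode([token_id]) for token_id in ids]
print(pieces)  # e.g. ['<', 'tool', '_call', '>'] rather than one single token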
diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
index c7030d34d453e..d126130ab9bc3 100644
--- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
@@ -52,14 +52,51 @@ class Hermes2ProToolParser(ToolParser):
             raise ValueError(
                 "The model tokenizer must be passed to the ToolParser "
                 "constructor during construction.")
-        self.tool_call_start_token_id = self.vocab.get(
-            self.tool_call_start_token)
-        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
-        if (self.tool_call_start_token_id is None
-                or self.tool_call_end_token_id is None):
-            raise RuntimeError(
-                "Hermes 2 Pro Tool parser could not locate tool call start/end "
-                "tokens in the tokenizer!")
+        self.tool_call_start_token_ids = self.model_tokenizer.encode(
+            self.tool_call_start_token, add_special_tokens=False)
+        self.tool_call_end_token_ids = self.model_tokenizer.encode(
+            self.tool_call_end_token, add_special_tokens=False)
+
+        self.tool_call_start_token_array = [
+            self.model_tokenizer.decode([token_id])
+            for token_id in self.tool_call_start_token_ids
+        ]
+
+        self.tool_call_end_token_array = [
+            self.model_tokenizer.decode([token_id])
+            for token_id in self.tool_call_end_token_ids
+        ]
+
+        self.buffered_delta_text = ""
+
+    # The idea is simple: when tokens that spell out a tool-call tag, e.g.
+    # <, tool, _call, > or <, /, tool, _call, >, arrive one by one,
+    # store them in a buffer. When the last token of the tag arrives,
+    # flush the buffer and return the buffered text plus that token. If a
+    # token breaks the expected sequence, return the buffer along with it.
+    def tool_call_delta_buffer(self, delta_text: str):
+        # If the sequence of tool_call_start or tool_call_end tokens is
+        # not yet complete, add the token to the buffer and return "".
+        if (delta_text in self.tool_call_start_token_array
+                or delta_text in self.tool_call_end_token_array):
+            # If delta_text is the last token of tool_call_start_token or
+            # tool_call_end_token, empty the buffer and return the
+            # buffered text + delta_text.
+            if (delta_text == self.tool_call_start_token_array[-1]
+                    or delta_text == self.tool_call_end_token_array[-1]):
+                buffered_text = self.buffered_delta_text
+                self.buffered_delta_text = ""
+                return buffered_text + delta_text
+            else:
+                self.buffered_delta_text += delta_text
+                return ""
+        else:
+            if self.buffered_delta_text:
+                buffered_text = self.buffered_delta_text
+                self.buffered_delta_text = ""
+                return buffered_text + delta_text
+            else:
+                return delta_text
 
     def extract_tool_calls(
         self,
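To make the buffering rules above concrete, here is a small self-contained sketch of the same logic; DeltaBuffer and the four-piece tag splits are stand-ins for illustration, not part of the patch:

# Stand-in sketch of tool_call_delta_buffer; the class name and the token
# splits are hypothetical, chosen only to illustrate the buffering rules.
class DeltaBuffer:

    def __init__(self, start_pieces: list[str], end_pieces: list[str]):
        self.start_pieces = start_pieces  # e.g. ["<", "tool", "_call", ">"]
        self.end_pieces = end_pieces  # e.g. ["</", "tool", "_call", ">"]
        self.buffered = ""

    def feed(self, delta_text: str) -> str:
        if (delta_text in self.start_pieces
                or delta_text in self.end_pieces):
            # Final piece of a tag: flush the buffer plus this piece.
            if (delta_text == self.start_pieces[-1]
                    or delta_text == self.end_pieces[-1]):
                flushed, self.buffered = self.buffered + delta_text, ""
                return flushed
            # Possible tag prefix: hold it back for now.
            self.buffered += delta_text
            return ""
        # Not a tag piece: release anything held back, plus this piece.
        flushed, self.buffered = self.buffered + delta_text, ""
        return flushed


buf = DeltaBuffer(["<", "tool", "_call", ">"], ["</", "tool", "_call", ">"])
# Pieces of a tag yield "" until the tag completes in a single flush:
assert ([buf.feed(p) for p in ["<", "tool", "_call", ">"]]
        == ["", "", "", "<tool_call>"])
# A token that breaks the sequence flushes the buffer along with itself:
assert buf.feed("<") == "" and buf.feed("hello") == "<hello"

The payoff is that downstream parsing never sees a half-emitted tag: a tag either appears whole in one delta or not at all, which is what lets the streaming code below match on plain text.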
@@ -124,11 +161,23 @@ class Hermes2ProToolParser(ToolParser):
         delta_token_ids: Sequence[int],
         request: ChatCompletionRequest,
     ) -> Union[DeltaMessage, None]:
+        # 1. Everything is parsed from the *_text arguments, not *_token_ids.
+        # 2. All incoming text is passed through tool_call_delta_buffer
+        # for buffering before it is used for parsing.
+
+        delta_text = self.tool_call_delta_buffer(delta_text)
+        # If the tail of previous_text matches self.buffered_delta_text,
+        # strip only that matching suffix.
+        if (len(previous_text) >= len(self.buffered_delta_text)
+                and previous_text[-len(self.buffered_delta_text):]
+                == self.buffered_delta_text):
+            previous_text = previous_text[:-len(self.buffered_delta_text)]
+        current_text = previous_text + delta_text
         logger.debug("delta_text: %s", delta_text)
         logger.debug("delta_token_ids: %s", delta_token_ids)
         # check to see if we should be streaming a tool call - is there a
-        if self.tool_call_start_token_id not in current_token_ids:
+        if self.tool_call_start_token not in current_text:
             logger.debug("No tool call tokens found!")
             return DeltaMessage(content=delta_text)
 
@@ -136,14 +185,12 @@ class Hermes2ProToolParser(ToolParser):
 
         # figure out where we are in the parsing by counting tool call
         # start & end tags
-        prev_tool_start_count = previous_token_ids.count(
-            self.tool_call_start_token_id)
-        prev_tool_end_count = previous_token_ids.count(
-            self.tool_call_end_token_id)
-        cur_tool_start_count = current_token_ids.count(
-            self.tool_call_start_token_id)
-        cur_tool_end_count = current_token_ids.count(
-            self.tool_call_end_token_id)
+        prev_tool_start_count = previous_text.count(
+            self.tool_call_start_token)
+        prev_tool_end_count = previous_text.count(self.tool_call_end_token)
+        cur_tool_start_count = current_text.count(
+            self.tool_call_start_token)
+        cur_tool_end_count = current_text.count(self.tool_call_end_token)
 
         tool_call_portion = None
         text_portion = None
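Finally, a quick standalone illustration of why the text-based tag counting at the end of the diff is enough to track streaming state. The classify helper and its string labels are mine, for exposition only (the real parser builds DeltaMessage objects), and the tag strings assume the usual Hermes <tool_call> markup:

# Expository sketch of the start/end tag counting above; classify() and its
# labels are hypothetical, the real parser emits DeltaMessage deltas.
START_TAG = "<tool_call>"
END_TAG = "</tool_call>"


def classify(previous_text: str, current_text: str) -> str:
    # Compare tag counts before and after the newest delta arrived.
    prev_starts = previous_text.count(START_TAG)
    prev_ends = previous_text.count(END_TAG)
    cur_starts = current_text.count(START_TAG)
    cur_ends = current_text.count(END_TAG)

    if cur_starts > prev_starts:
        return "tool call started"
    if cur_ends > prev_ends:
        return "tool call finished"
    if cur_starts > cur_ends:
        return "inside tool call"
    return "plain content"


assert classify("", "Sure!") == "plain content"
assert classify("Sure!", "Sure!" + START_TAG) == "tool call started"
assert classify("Sure!" + START_TAG,
                "Sure!" + START_TAG + '{"name"') == "inside tool call"
assert (classify(START_TAG + "{}", START_TAG + "{}" + END_TAG)
        == "tool call finished")

Note that counting substrings is reliable here only because tool_call_delta_buffer guarantees a tag never lands split across previous_text and delta_text; without the buffering, a partially emitted "<tool" would be invisible to the counts and leak into the content stream.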