diff --git a/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py index 4bab849f47c2..e0e6b2c07e17 100644 --- a/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py @@ -45,8 +45,39 @@ TOOLS = [{ }, }] +PRODUCT_TOOLS = [{ + "type": "function", + "function": { + "name": "get_product_info", + "description": "Get detailed information of a product based on its " + "product ID.", + "parameters": { + "type": "object", + "properties": { + "inserted": { + "type": "boolean", + "description": "inserted.", + }, + "product_id": { + "type": "integer", + "description": "The product ID of the product.", + }, + }, + "required": ["product_id", "inserted"], + }, + }, +}] + MESSAGES = [{"role": "user", "content": "What's the weather like in Boston?"}] +PRODUCT_MESSAGES = [{ + "role": + "user", + "content": + "Hi! Do you have any detailed information about the product id " + "7355608 and inserted true?" +}] + @pytest.mark.asyncio async def test_non_streaming_tool_call(): @@ -127,3 +158,103 @@ async def test_streaming_tool_call(): print("\n[Streaming Test Passed]") print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}") print(f"Reconstructed Arguments: {arguments}") + + +@pytest.mark.asyncio +async def test_non_streaming_product_tool_call(): + """Test tool call integer and boolean parameters in non-streaming mode.""" + with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server: + client = server.get_async_client() + + response = await client.chat.completions.create( + model=LORA_MODEL, + messages=PRODUCT_MESSAGES, + tools=PRODUCT_TOOLS, + tool_choice="auto", + temperature=0.66, + ) + + assert response.choices + choice = response.choices[0] + message = choice.message + + assert choice.finish_reason == "tool_calls" + assert message.tool_calls is not None + + tool_call = message.tool_calls[0] + assert tool_call.type == "function" + assert tool_call.function.name == "get_product_info" + + arguments = json.loads(tool_call.function.arguments) + assert "product_id" in arguments + assert "inserted" in arguments + + product_id = arguments.get("product_id") + inserted = arguments.get("inserted") + + assert isinstance(product_id, int) + assert product_id == 7355608 + assert isinstance(inserted, bool) + assert inserted is True + + print("\n[Non-Streaming Product Test Passed]") + print(f"Tool Call: {tool_call.function.name}") + print(f"Arguments: {arguments}") + + +@pytest.mark.asyncio +async def test_streaming_product_tool_call(): + """Test tool call integer and boolean parameters in streaming mode.""" + with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server: + client = server.get_async_client() + + stream = await client.chat.completions.create( + model=LORA_MODEL, + messages=PRODUCT_MESSAGES, + tools=PRODUCT_TOOLS, + tool_choice="auto", + temperature=0.66, + stream=True, + ) + + tool_call_chunks = {} + async for chunk in stream: + if not chunk.choices: + continue + + delta = chunk.choices[0].delta + if not delta or not delta.tool_calls: + continue + + for tool_chunk in delta.tool_calls: + index = tool_chunk.index + if index not in tool_call_chunks: + tool_call_chunks[index] = {"name": "", "arguments": ""} + + if tool_chunk.function.name: + tool_call_chunks[index]["name"] += tool_chunk.function.name + if tool_chunk.function.arguments: + tool_call_chunks[index][ + "arguments"] += tool_chunk.function.arguments + + assert len(tool_call_chunks) == 1 + reconstructed_tool_call = tool_call_chunks[0] + + assert reconstructed_tool_call["name"] == "get_product_info" + + arguments = json.loads(reconstructed_tool_call["arguments"]) + assert "product_id" in arguments + assert "inserted" in arguments + + # Handle type coercion for streaming test as well + product_id = arguments.get("product_id") + inserted = arguments.get("inserted") + + assert isinstance(product_id, int) + assert product_id == 7355608 + assert isinstance(inserted, bool) + assert inserted is True + + print("\n[Streaming Product Test Passed]") + print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}") + print(f"Reconstructed Arguments: {arguments}") diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index e74c420da1d3..87595953da06 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -368,16 +368,32 @@ class Hermes2ProToolParser(ToolParser): # case -- we now have the first info about arguments available from # autocompleting the JSON elif cur_arguments and not prev_arguments: + # extract the content after {"name": ..., "arguments": + # directly from tool_call_portion as cur_arguments_json, + # since cur_arguments may differ from the original text + # due to partial JSON parsing + # for example, tool_call_portion = + # {"name": "search", "arguments": {"search_request": {" + # but cur_arguments = + # {"search_request": {}} + function_name = current_tool_call.get("name") + match = re.search( + r'\{"name":\s*"' + + re.escape(function_name) + r'"\s*,\s*"arguments":\s*(.*)', + tool_call_portion.strip(), re.DOTALL) + if match: + cur_arguments_json = match.group(1) + else: + cur_arguments_json = json.dumps(cur_arguments, + ensure_ascii=False) - cur_arguments_json = json.dumps(cur_arguments, - ensure_ascii=False) logger.debug("finding %s in %s", delta_text, cur_arguments_json) - # get the location where previous args differ from current - if (delta_text not in cur_arguments_json[:-2]): + # get the location where previous args differ from current. + if (delta_text not in cur_arguments_json): return None - args_delta_start_loc = cur_arguments_json[:-2]. \ + args_delta_start_loc = cur_arguments_json. \ rindex(delta_text) + \ len(delta_text) @@ -397,8 +413,20 @@ class Hermes2ProToolParser(ToolParser): # last case -- we have an update to existing arguments. elif cur_arguments and prev_arguments: - if isinstance(delta_text, str) and len(delta_text.rstrip( - )) >= 1 and delta_text.rstrip()[-1] == '}': + # judge whether the tool_call_portion is a complete JSON + try: + json.loads(tool_call_portion) + is_complete_json = True + except Exception: + is_complete_json = False + + # if the delta_text ends with a '}' and tool_call_portion is a + # complete JSON, then the last '}' does not belong to the + # arguments, so we should trim it off + if isinstance(delta_text, str) \ + and len(delta_text.rstrip()) >= 1 \ + and delta_text.rstrip()[-1] == '}' \ + and is_complete_json: delta_text = delta_text.rstrip()[:-1] logger.debug("got diff %s", delta_text)