mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-20 11:45:57 +08:00
[Bugfix] Fix hermes tool parser handling of non-string argument types (#22002)
Signed-off-by: wangzi <3220100013@zju.edu.cn> Signed-off-by: David Chen <530634352@qq.com> Co-authored-by: wangzi <3220100013@zju.edu.cn> Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
This commit is contained in:
parent
793be8d057
commit
0eecb31663
@ -45,8 +45,39 @@ TOOLS = [{
|
|||||||
},
|
},
|
||||||
}]
|
}]
|
||||||
|
|
||||||
|
PRODUCT_TOOLS = [{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_product_info",
|
||||||
|
"description": "Get detailed information of a product based on its "
|
||||||
|
"product ID.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"inserted": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "inserted.",
|
||||||
|
},
|
||||||
|
"product_id": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "The product ID of the product.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["product_id", "inserted"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}]
|
||||||
|
|
||||||
MESSAGES = [{"role": "user", "content": "What's the weather like in Boston?"}]
|
MESSAGES = [{"role": "user", "content": "What's the weather like in Boston?"}]
|
||||||
|
|
||||||
|
PRODUCT_MESSAGES = [{
|
||||||
|
"role":
|
||||||
|
"user",
|
||||||
|
"content":
|
||||||
|
"Hi! Do you have any detailed information about the product id "
|
||||||
|
"7355608 and inserted true?"
|
||||||
|
}]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_non_streaming_tool_call():
|
async def test_non_streaming_tool_call():
|
||||||
@ -127,3 +158,103 @@ async def test_streaming_tool_call():
|
|||||||
print("\n[Streaming Test Passed]")
|
print("\n[Streaming Test Passed]")
|
||||||
print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}")
|
print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}")
|
||||||
print(f"Reconstructed Arguments: {arguments}")
|
print(f"Reconstructed Arguments: {arguments}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_non_streaming_product_tool_call():
|
||||||
|
"""Test tool call integer and boolean parameters in non-streaming mode."""
|
||||||
|
with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server:
|
||||||
|
client = server.get_async_client()
|
||||||
|
|
||||||
|
response = await client.chat.completions.create(
|
||||||
|
model=LORA_MODEL,
|
||||||
|
messages=PRODUCT_MESSAGES,
|
||||||
|
tools=PRODUCT_TOOLS,
|
||||||
|
tool_choice="auto",
|
||||||
|
temperature=0.66,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.choices
|
||||||
|
choice = response.choices[0]
|
||||||
|
message = choice.message
|
||||||
|
|
||||||
|
assert choice.finish_reason == "tool_calls"
|
||||||
|
assert message.tool_calls is not None
|
||||||
|
|
||||||
|
tool_call = message.tool_calls[0]
|
||||||
|
assert tool_call.type == "function"
|
||||||
|
assert tool_call.function.name == "get_product_info"
|
||||||
|
|
||||||
|
arguments = json.loads(tool_call.function.arguments)
|
||||||
|
assert "product_id" in arguments
|
||||||
|
assert "inserted" in arguments
|
||||||
|
|
||||||
|
product_id = arguments.get("product_id")
|
||||||
|
inserted = arguments.get("inserted")
|
||||||
|
|
||||||
|
assert isinstance(product_id, int)
|
||||||
|
assert product_id == 7355608
|
||||||
|
assert isinstance(inserted, bool)
|
||||||
|
assert inserted is True
|
||||||
|
|
||||||
|
print("\n[Non-Streaming Product Test Passed]")
|
||||||
|
print(f"Tool Call: {tool_call.function.name}")
|
||||||
|
print(f"Arguments: {arguments}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_streaming_product_tool_call():
|
||||||
|
"""Test tool call integer and boolean parameters in streaming mode."""
|
||||||
|
with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server:
|
||||||
|
client = server.get_async_client()
|
||||||
|
|
||||||
|
stream = await client.chat.completions.create(
|
||||||
|
model=LORA_MODEL,
|
||||||
|
messages=PRODUCT_MESSAGES,
|
||||||
|
tools=PRODUCT_TOOLS,
|
||||||
|
tool_choice="auto",
|
||||||
|
temperature=0.66,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_call_chunks = {}
|
||||||
|
async for chunk in stream:
|
||||||
|
if not chunk.choices:
|
||||||
|
continue
|
||||||
|
|
||||||
|
delta = chunk.choices[0].delta
|
||||||
|
if not delta or not delta.tool_calls:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for tool_chunk in delta.tool_calls:
|
||||||
|
index = tool_chunk.index
|
||||||
|
if index not in tool_call_chunks:
|
||||||
|
tool_call_chunks[index] = {"name": "", "arguments": ""}
|
||||||
|
|
||||||
|
if tool_chunk.function.name:
|
||||||
|
tool_call_chunks[index]["name"] += tool_chunk.function.name
|
||||||
|
if tool_chunk.function.arguments:
|
||||||
|
tool_call_chunks[index][
|
||||||
|
"arguments"] += tool_chunk.function.arguments
|
||||||
|
|
||||||
|
assert len(tool_call_chunks) == 1
|
||||||
|
reconstructed_tool_call = tool_call_chunks[0]
|
||||||
|
|
||||||
|
assert reconstructed_tool_call["name"] == "get_product_info"
|
||||||
|
|
||||||
|
arguments = json.loads(reconstructed_tool_call["arguments"])
|
||||||
|
assert "product_id" in arguments
|
||||||
|
assert "inserted" in arguments
|
||||||
|
|
||||||
|
# Handle type coercion for streaming test as well
|
||||||
|
product_id = arguments.get("product_id")
|
||||||
|
inserted = arguments.get("inserted")
|
||||||
|
|
||||||
|
assert isinstance(product_id, int)
|
||||||
|
assert product_id == 7355608
|
||||||
|
assert isinstance(inserted, bool)
|
||||||
|
assert inserted is True
|
||||||
|
|
||||||
|
print("\n[Streaming Product Test Passed]")
|
||||||
|
print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}")
|
||||||
|
print(f"Reconstructed Arguments: {arguments}")
|
||||||
|
|||||||
@ -368,16 +368,32 @@ class Hermes2ProToolParser(ToolParser):
|
|||||||
# case -- we now have the first info about arguments available from
|
# case -- we now have the first info about arguments available from
|
||||||
# autocompleting the JSON
|
# autocompleting the JSON
|
||||||
elif cur_arguments and not prev_arguments:
|
elif cur_arguments and not prev_arguments:
|
||||||
|
# extract the content after {"name": ..., "arguments":
|
||||||
|
# directly from tool_call_portion as cur_arguments_json,
|
||||||
|
# since cur_arguments may differ from the original text
|
||||||
|
# due to partial JSON parsing
|
||||||
|
# for example, tool_call_portion =
|
||||||
|
# {"name": "search", "arguments": {"search_request": {"
|
||||||
|
# but cur_arguments =
|
||||||
|
# {"search_request": {}}
|
||||||
|
function_name = current_tool_call.get("name")
|
||||||
|
match = re.search(
|
||||||
|
r'\{"name":\s*"' +
|
||||||
|
re.escape(function_name) + r'"\s*,\s*"arguments":\s*(.*)',
|
||||||
|
tool_call_portion.strip(), re.DOTALL)
|
||||||
|
if match:
|
||||||
|
cur_arguments_json = match.group(1)
|
||||||
|
else:
|
||||||
cur_arguments_json = json.dumps(cur_arguments,
|
cur_arguments_json = json.dumps(cur_arguments,
|
||||||
ensure_ascii=False)
|
ensure_ascii=False)
|
||||||
|
|
||||||
logger.debug("finding %s in %s", delta_text,
|
logger.debug("finding %s in %s", delta_text,
|
||||||
cur_arguments_json)
|
cur_arguments_json)
|
||||||
|
|
||||||
# get the location where previous args differ from current
|
# get the location where previous args differ from current.
|
||||||
if (delta_text not in cur_arguments_json[:-2]):
|
if (delta_text not in cur_arguments_json):
|
||||||
return None
|
return None
|
||||||
args_delta_start_loc = cur_arguments_json[:-2]. \
|
args_delta_start_loc = cur_arguments_json. \
|
||||||
rindex(delta_text) + \
|
rindex(delta_text) + \
|
||||||
len(delta_text)
|
len(delta_text)
|
||||||
|
|
||||||
@ -397,8 +413,20 @@ class Hermes2ProToolParser(ToolParser):
|
|||||||
|
|
||||||
# last case -- we have an update to existing arguments.
|
# last case -- we have an update to existing arguments.
|
||||||
elif cur_arguments and prev_arguments:
|
elif cur_arguments and prev_arguments:
|
||||||
if isinstance(delta_text, str) and len(delta_text.rstrip(
|
# judge whether the tool_call_portion is a complete JSON
|
||||||
)) >= 1 and delta_text.rstrip()[-1] == '}':
|
try:
|
||||||
|
json.loads(tool_call_portion)
|
||||||
|
is_complete_json = True
|
||||||
|
except Exception:
|
||||||
|
is_complete_json = False
|
||||||
|
|
||||||
|
# if the delta_text ends with a '}' and tool_call_portion is a
|
||||||
|
# complete JSON, then the last '}' does not belong to the
|
||||||
|
# arguments, so we should trim it off
|
||||||
|
if isinstance(delta_text, str) \
|
||||||
|
and len(delta_text.rstrip()) >= 1 \
|
||||||
|
and delta_text.rstrip()[-1] == '}' \
|
||||||
|
and is_complete_json:
|
||||||
delta_text = delta_text.rstrip()[:-1]
|
delta_text = delta_text.rstrip()[:-1]
|
||||||
|
|
||||||
logger.debug("got diff %s", delta_text)
|
logger.debug("got diff %s", delta_text)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user