mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-18 11:45:39 +08:00
[Test]: Hermes tool parser stream output error in Qwen3 case (#25203)
Signed-off-by: Andreas Hartel <andreas.hartel@aleph-alpha.com>
This commit is contained in:
parent
babad6e5dd
commit
4322c553a6
@ -5,6 +5,11 @@ import json
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
|
||||||
|
from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import (
|
||||||
|
Hermes2ProToolParser)
|
||||||
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
|
|
||||||
from ....utils import RemoteOpenAIServer
|
from ....utils import RemoteOpenAIServer
|
||||||
|
|
||||||
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
|
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
|
||||||
@ -37,7 +42,7 @@ TOOLS = [{
|
|||||||
},
|
},
|
||||||
"unit": {
|
"unit": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"enum": ["celsius", "fahrenheit"]
|
"enum": ["celsius", "fahrenheit"],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"required": ["location"],
|
"required": ["location"],
|
||||||
@ -75,7 +80,7 @@ PRODUCT_MESSAGES = [{
|
|||||||
"user",
|
"user",
|
||||||
"content":
|
"content":
|
||||||
"Hi! Do you have any detailed information about the product id "
|
"Hi! Do you have any detailed information about the product id "
|
||||||
"7355608 and inserted true?"
|
"7355608 and inserted true?",
|
||||||
}]
|
}]
|
||||||
|
|
||||||
|
|
||||||
@ -144,8 +149,8 @@ async def test_streaming_tool_call():
|
|||||||
if tool_chunk.function.name:
|
if tool_chunk.function.name:
|
||||||
tool_call_chunks[index]["name"] += tool_chunk.function.name
|
tool_call_chunks[index]["name"] += tool_chunk.function.name
|
||||||
if tool_chunk.function.arguments:
|
if tool_chunk.function.arguments:
|
||||||
tool_call_chunks[index][
|
tool_call_chunks[index]["arguments"] += (
|
||||||
"arguments"] += tool_chunk.function.arguments
|
tool_chunk.function.arguments)
|
||||||
|
|
||||||
assert len(tool_call_chunks) == 1
|
assert len(tool_call_chunks) == 1
|
||||||
reconstructed_tool_call = tool_call_chunks[0]
|
reconstructed_tool_call = tool_call_chunks[0]
|
||||||
@ -234,8 +239,8 @@ async def test_streaming_product_tool_call():
|
|||||||
if tool_chunk.function.name:
|
if tool_chunk.function.name:
|
||||||
tool_call_chunks[index]["name"] += tool_chunk.function.name
|
tool_call_chunks[index]["name"] += tool_chunk.function.name
|
||||||
if tool_chunk.function.arguments:
|
if tool_chunk.function.arguments:
|
||||||
tool_call_chunks[index][
|
tool_call_chunks[index]["arguments"] += (
|
||||||
"arguments"] += tool_chunk.function.arguments
|
tool_chunk.function.arguments)
|
||||||
|
|
||||||
assert len(tool_call_chunks) == 1
|
assert len(tool_call_chunks) == 1
|
||||||
reconstructed_tool_call = tool_call_chunks[0]
|
reconstructed_tool_call = tool_call_chunks[0]
|
||||||
@ -258,3 +263,195 @@ async def test_streaming_product_tool_call():
|
|||||||
print("\n[Streaming Product Test Passed]")
|
print("\n[Streaming Product Test Passed]")
|
||||||
print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}")
|
print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}")
|
||||||
print(f"Reconstructed Arguments: {arguments}")
|
print(f"Reconstructed Arguments: {arguments}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def qwen_tokenizer() -> AnyTokenizer:
    """Return the ``Qwen/Qwen3-32B`` tokenizer for the parser tests."""
    # Kept as a function-local import, as in the original fixture.
    from vllm.transformers_utils.tokenizer import get_tokenizer

    tokenizer = get_tokenizer("Qwen/Qwen3-32B")
    return tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def hermes_parser(qwen_tokenizer: AnyTokenizer) -> Hermes2ProToolParser:
    """Build a ``Hermes2ProToolParser`` around the Qwen tokenizer fixture."""
    parser = Hermes2ProToolParser(qwen_tokenizer)
    return parser
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def any_chat_request() -> ChatCompletionRequest:
    """Build a minimal chat request: fixed seed, Qwen model, no messages."""
    chat_request = ChatCompletionRequest(
        model="Qwen/Qwen3-32B",
        messages=[],
        seed=42,
    )
    return chat_request
|
||||||
|
|
||||||
|
|
||||||
|
def test_hermes_parser_streaming_just_forward_text(
    qwen_tokenizer: AnyTokenizer,
    hermes_parser: Hermes2ProToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """Stream plain text through the parser one token at a time.

    Every step must yield a delta that carries no tool calls, and the
    concatenated delta contents must equal the input text.
    """
    text = "This is some prior text that has nothing to do with tool calling."
    token_ids = qwen_tokenizer.encode(text)

    streamed_so_far = ""
    delta_messages = []
    for token_id in token_ids:
        piece = qwen_tokenizer.decode([token_id])
        extended = streamed_so_far + piece
        # Only the text arguments are populated; the token id lists are
        # passed empty on every call.
        delta_messages.append(
            hermes_parser.extract_tool_calls_streaming(
                previous_text=streamed_so_far,
                current_text=extended,
                delta_text=piece,
                previous_token_ids=[],
                current_token_ids=[],
                delta_token_ids=[],
                request=any_chat_request,
            ))
        streamed_so_far = extended

    for message in delta_messages:
        assert message is not None
        assert not message.tool_calls

    print(delta_messages)
    assert "".join(message.content for message in delta_messages) == text
|
||||||
|
|
||||||
|
|
||||||
|
def test_hermes_parser_streaming_failure_case_bug_19056(
    qwen_tokenizer: AnyTokenizer,
    hermes_parser: Hermes2ProToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """Stream a complete ``final_answer`` tool call token by token.

    The first retained (non-None) delta must carry the function name, and
    the argument fragments of all retained deltas must join to
    ``{"trigger": true}``.
    """
    text = ("<tool_call>\n"
            '{"name": "final_answer", "arguments": {"trigger": true}}\n'
            "</tool_call>")
    token_ids = qwen_tokenizer.encode(text)

    streamed_so_far = ""
    delta_messages = []
    for token_id in token_ids:
        delta_text = qwen_tokenizer.decode([token_id])
        extended = streamed_so_far + delta_text
        delta = hermes_parser.extract_tool_calls_streaming(
            previous_text=streamed_so_far,
            current_text=extended,
            delta_text=delta_text,
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[],
            request=any_chat_request,
        )
        streamed_so_far = extended
        # Steps for which the parser returns None are dropped.
        if delta is None:
            continue
        delta_messages.append(delta)

    first_call = delta_messages[0].tool_calls[0]
    assert first_call.function.name == "final_answer"

    argument_parts = [
        message.tool_calls[0].function.arguments or ""
        for message in delta_messages
    ]
    assert "".join(argument_parts) == '{"trigger": true}'
|
||||||
|
|
||||||
|
|
||||||
|
def test_hermes_parser_streaming(
    qwen_tokenizer: AnyTokenizer,
    hermes_parser: Hermes2ProToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """Stream a ``get_current_temperature`` tool call token by token.

    The first retained (non-None) delta must carry the function name, and
    the argument fragments of all retained deltas must join to the exact
    argument JSON text that appears in the input.
    """
    # Adjacent literals concatenate to one single-line string: no newlines
    # and no extra whitespace between the pieces.
    text = ('<tool_call>'
            '{"name": "get_current_temperature",'
            '"arguments": {"location":'
            '"San Francisco, California, United States", "unit": "celsius"}}'
            '</tool_call>')
    token_ids = qwen_tokenizer.encode(text)

    streamed_so_far = ""
    delta_messages = []
    for token_id in token_ids:
        delta_text = qwen_tokenizer.decode([token_id])
        extended = streamed_so_far + delta_text
        delta = hermes_parser.extract_tool_calls_streaming(
            previous_text=streamed_so_far,
            current_text=extended,
            delta_text=delta_text,
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[],
            request=any_chat_request,
        )
        streamed_so_far = extended
        # Steps for which the parser returns None are dropped.
        if delta is None:
            continue
        delta_messages.append(delta)

    print(delta_messages)

    first_call = delta_messages[0].tool_calls[0]
    assert first_call.function.name == "get_current_temperature"

    argument_parts = [
        message.tool_calls[0].function.arguments or ""
        for message in delta_messages
    ]
    expected_arguments = (
        '{"location":"San Francisco, California, United States", '
        '"unit": "celsius"}')
    assert "".join(argument_parts) == expected_arguments
|
||||||
|
|
||||||
|
|
||||||
|
def test_hermes_parser_non_streaming_no_tool_call(
    hermes_parser: Hermes2ProToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """Plain text without tool-call tags must not be reported as a call."""
    extraction = hermes_parser.extract_tool_calls(
        model_output="This is not a tool call.",
        request=any_chat_request,
    )

    assert extraction is not None
    assert not extraction.tools_called
|
||||||
|
|
||||||
|
|
||||||
|
def test_hermes_parser_non_streaming_tool_call_between_tags(
    hermes_parser: Hermes2ProToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """A call wrapped in opening and closing tags must be extracted."""
    model_output = (
        "<tool_call>\n"
        '{"name": "final_answer", "arguments": {"trigger": true}}\n'
        "</tool_call>")
    extraction = hermes_parser.extract_tool_calls(
        model_output=model_output,
        request=any_chat_request,
    )

    assert extraction is not None
    assert extraction.tools_called

    function = extraction.tool_calls[0].function
    assert function.name == "final_answer"
    assert function.arguments == '{"trigger": true}'
|
||||||
|
|
||||||
|
|
||||||
|
def test_hermes_parser_non_streaming_tool_call_until_eos(
    hermes_parser: Hermes2ProToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """A call with no closing tag before end of output must be extracted."""
    model_output = (
        "<tool_call>\n"
        '{"name": "final_answer", "arguments": {"trigger": true}}')
    extraction = hermes_parser.extract_tool_calls(
        model_output=model_output,
        request=any_chat_request,
    )

    assert extraction is not None
    assert extraction.tools_called

    function = extraction.tool_calls[0].function
    assert function.name == "final_answer"
    assert function.arguments == '{"trigger": true}'
|
||||||
|
|
||||||
|
|
||||||
|
def test_hermes_parser_non_streaming_tool_call_invalid_json(
    hermes_parser: Hermes2ProToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """Malformed JSON after the opening tag must not count as a tool call."""
    # The payload deliberately lacks its final closing brace so that
    # parsing it triggers an exception.
    model_output = (
        "<tool_call>\n"
        '{"name": "final_answer", "arguments": {"trigger": true}')
    extraction = hermes_parser.extract_tool_calls(
        model_output=model_output,
        request=any_chat_request,
    )

    assert extraction is not None
    assert not extraction.tools_called
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user