mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 20:04:27 +08:00
[Test]: Hermes tool parser stream output error in Qwen3 case (#25203)
Signed-off-by: Andreas Hartel <andreas.hartel@aleph-alpha.com>
This commit is contained in:
parent
babad6e5dd
commit
4322c553a6
@ -5,6 +5,11 @@ import json
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
|
||||
from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import (
|
||||
Hermes2ProToolParser)
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
from ....utils import RemoteOpenAIServer
|
||||
|
||||
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
|
||||
@ -37,7 +42,7 @@ TOOLS = [{
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"]
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
},
|
||||
"required": ["location"],
|
||||
@ -75,7 +80,7 @@ PRODUCT_MESSAGES = [{
|
||||
"user",
|
||||
"content":
|
||||
"Hi! Do you have any detailed information about the product id "
|
||||
"7355608 and inserted true?"
|
||||
"7355608 and inserted true?",
|
||||
}]
|
||||
|
||||
|
||||
@ -144,8 +149,8 @@ async def test_streaming_tool_call():
|
||||
if tool_chunk.function.name:
|
||||
tool_call_chunks[index]["name"] += tool_chunk.function.name
|
||||
if tool_chunk.function.arguments:
|
||||
tool_call_chunks[index][
|
||||
"arguments"] += tool_chunk.function.arguments
|
||||
tool_call_chunks[index]["arguments"] += (
|
||||
tool_chunk.function.arguments)
|
||||
|
||||
assert len(tool_call_chunks) == 1
|
||||
reconstructed_tool_call = tool_call_chunks[0]
|
||||
@ -234,8 +239,8 @@ async def test_streaming_product_tool_call():
|
||||
if tool_chunk.function.name:
|
||||
tool_call_chunks[index]["name"] += tool_chunk.function.name
|
||||
if tool_chunk.function.arguments:
|
||||
tool_call_chunks[index][
|
||||
"arguments"] += tool_chunk.function.arguments
|
||||
tool_call_chunks[index]["arguments"] += (
|
||||
tool_chunk.function.arguments)
|
||||
|
||||
assert len(tool_call_chunks) == 1
|
||||
reconstructed_tool_call = tool_call_chunks[0]
|
||||
@ -258,3 +263,195 @@ async def test_streaming_product_tool_call():
|
||||
print("\n[Streaming Product Test Passed]")
|
||||
print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}")
|
||||
print(f"Reconstructed Arguments: {arguments}")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def qwen_tokenizer() -> AnyTokenizer:
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
return get_tokenizer("Qwen/Qwen3-32B")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hermes_parser(qwen_tokenizer: AnyTokenizer) -> Hermes2ProToolParser:
|
||||
return Hermes2ProToolParser(qwen_tokenizer)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def any_chat_request() -> ChatCompletionRequest:
|
||||
return ChatCompletionRequest(
|
||||
seed=42,
|
||||
model="Qwen/Qwen3-32B",
|
||||
messages=[],
|
||||
)
|
||||
|
||||
|
||||
def test_hermes_parser_streaming_just_forward_text(
|
||||
qwen_tokenizer: AnyTokenizer,
|
||||
hermes_parser: Hermes2ProToolParser,
|
||||
any_chat_request: ChatCompletionRequest,
|
||||
) -> None:
|
||||
text = (
|
||||
"""This is some prior text that has nothing to do with tool calling."""
|
||||
)
|
||||
tokens = qwen_tokenizer.encode(text)
|
||||
previous_text = ""
|
||||
delta_messages = []
|
||||
for token in tokens:
|
||||
delta_text = qwen_tokenizer.decode([token])
|
||||
current_text = previous_text + delta_text
|
||||
delta = hermes_parser.extract_tool_calls_streaming(
|
||||
previous_text=previous_text,
|
||||
current_text=current_text,
|
||||
delta_text=delta_text,
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[],
|
||||
request=any_chat_request,
|
||||
)
|
||||
previous_text = current_text
|
||||
delta_messages.append(delta)
|
||||
|
||||
for delta in delta_messages:
|
||||
assert delta is not None
|
||||
assert not delta.tool_calls
|
||||
|
||||
print(delta_messages)
|
||||
assert "".join([delta.content for delta in delta_messages]) == text
|
||||
|
||||
|
||||
def test_hermes_parser_streaming_failure_case_bug_19056(
|
||||
qwen_tokenizer: AnyTokenizer,
|
||||
hermes_parser: Hermes2ProToolParser,
|
||||
any_chat_request: ChatCompletionRequest,
|
||||
) -> None:
|
||||
text = """<tool_call>
|
||||
{"name": "final_answer", "arguments": {"trigger": true}}
|
||||
</tool_call>"""
|
||||
tokens = qwen_tokenizer.encode(text)
|
||||
previous_text = ""
|
||||
delta_messages = []
|
||||
for token in tokens:
|
||||
text = qwen_tokenizer.decode([token])
|
||||
current_text = previous_text + text
|
||||
delta = hermes_parser.extract_tool_calls_streaming(
|
||||
previous_text=previous_text,
|
||||
current_text=current_text,
|
||||
delta_text=text,
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[],
|
||||
request=any_chat_request,
|
||||
)
|
||||
previous_text = current_text
|
||||
if delta is not None:
|
||||
delta_messages.append(delta)
|
||||
|
||||
assert delta_messages[0].tool_calls[0].function.name == "final_answer"
|
||||
tool_call_args = "".join(delta.tool_calls[0].function.arguments or ""
|
||||
for delta in delta_messages)
|
||||
assert tool_call_args == '{"trigger": true}'
|
||||
|
||||
|
||||
def test_hermes_parser_streaming(
|
||||
qwen_tokenizer: AnyTokenizer,
|
||||
hermes_parser: Hermes2ProToolParser,
|
||||
any_chat_request: ChatCompletionRequest,
|
||||
) -> None:
|
||||
text = '<tool_call>\
|
||||
{"name": "get_current_temperature",\
|
||||
"arguments": {"location":\
|
||||
"San Francisco, California, United States", "unit": "celsius"}}\
|
||||
</tool_call>'
|
||||
|
||||
tokens = qwen_tokenizer.encode(text)
|
||||
previous_text = ""
|
||||
delta_messages = []
|
||||
for token in tokens:
|
||||
text = qwen_tokenizer.decode([token])
|
||||
current_text = previous_text + text
|
||||
delta = hermes_parser.extract_tool_calls_streaming(
|
||||
previous_text=previous_text,
|
||||
current_text=current_text,
|
||||
delta_text=text,
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[],
|
||||
request=any_chat_request,
|
||||
)
|
||||
previous_text = current_text
|
||||
if delta is not None:
|
||||
delta_messages.append(delta)
|
||||
print(delta_messages)
|
||||
assert (delta_messages[0].tool_calls[0].function.name ==
|
||||
"get_current_temperature")
|
||||
tool_call_args = "".join(delta.tool_calls[0].function.arguments or ""
|
||||
for delta in delta_messages)
|
||||
assert tool_call_args == (
|
||||
'{"location":"San Francisco, California, United States", '
|
||||
'"unit": "celsius"}')
|
||||
|
||||
|
||||
def test_hermes_parser_non_streaming_no_tool_call(
|
||||
hermes_parser: Hermes2ProToolParser,
|
||||
any_chat_request: ChatCompletionRequest,
|
||||
) -> None:
|
||||
text = """This is not a tool call."""
|
||||
tool_call = hermes_parser.extract_tool_calls(
|
||||
model_output=text,
|
||||
request=any_chat_request,
|
||||
)
|
||||
|
||||
assert tool_call is not None
|
||||
assert not tool_call.tools_called
|
||||
|
||||
|
||||
def test_hermes_parser_non_streaming_tool_call_between_tags(
|
||||
hermes_parser: Hermes2ProToolParser,
|
||||
any_chat_request: ChatCompletionRequest,
|
||||
) -> None:
|
||||
text = """<tool_call>
|
||||
{"name": "final_answer", "arguments": {"trigger": true}}
|
||||
</tool_call>"""
|
||||
tool_call = hermes_parser.extract_tool_calls(
|
||||
model_output=text,
|
||||
request=any_chat_request,
|
||||
)
|
||||
|
||||
assert tool_call is not None
|
||||
assert tool_call.tools_called
|
||||
assert tool_call.tool_calls[0].function.name == "final_answer"
|
||||
assert tool_call.tool_calls[0].function.arguments == '{"trigger": true}'
|
||||
|
||||
|
||||
def test_hermes_parser_non_streaming_tool_call_until_eos(
|
||||
hermes_parser: Hermes2ProToolParser,
|
||||
any_chat_request: ChatCompletionRequest,
|
||||
) -> None:
|
||||
text = """<tool_call>
|
||||
{"name": "final_answer", "arguments": {"trigger": true}}"""
|
||||
tool_call = hermes_parser.extract_tool_calls(
|
||||
model_output=text,
|
||||
request=any_chat_request,
|
||||
)
|
||||
|
||||
assert tool_call is not None
|
||||
assert tool_call.tools_called
|
||||
assert tool_call.tool_calls[0].function.name == "final_answer"
|
||||
assert tool_call.tool_calls[0].function.arguments == '{"trigger": true}'
|
||||
|
||||
|
||||
def test_hermes_parser_non_streaming_tool_call_invalid_json(
|
||||
hermes_parser: Hermes2ProToolParser,
|
||||
any_chat_request: ChatCompletionRequest,
|
||||
) -> None:
|
||||
# Missing closing brace to trigger exception
|
||||
text = """<tool_call>
|
||||
{"name": "final_answer", "arguments": {"trigger": true}"""
|
||||
tool_call = hermes_parser.extract_tool_calls(
|
||||
model_output=text,
|
||||
request=any_chat_request,
|
||||
)
|
||||
|
||||
assert tool_call is not None
|
||||
assert not tool_call.tools_called
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user