mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 05:34:57 +08:00
[Frontend] [gpt-oss] Tool json call parsing error retry (#27675)
Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com> Signed-off-by: Alec Solder <alecs@fb.com> Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com> Co-authored-by: Jialin Ouyang <Jialin.Ouyang@gmail.com> Co-authored-by: Alec Solder <alecs@fb.com> Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com>
This commit is contained in:
parent
1891cf605a
commit
3c7fefdeba
@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, Union
|
|||||||
from openai.types.responses.tool import Mcp
|
from openai.types.responses.tool import Mcp
|
||||||
from openai_harmony import Author, Message, Role, StreamState, TextContent
|
from openai_harmony import Author, Message, Role, StreamState, TextContent
|
||||||
|
|
||||||
|
from vllm import envs
|
||||||
from vllm.entrypoints.harmony_utils import (
|
from vllm.entrypoints.harmony_utils import (
|
||||||
get_encoding,
|
get_encoding,
|
||||||
get_streamable_parser_for_assistant,
|
get_streamable_parser_for_assistant,
|
||||||
@ -109,6 +110,28 @@ class ConversationContext(ABC):
|
|||||||
raise NotImplementedError("Should not be called.")
|
raise NotImplementedError("Should not be called.")
|
||||||
|
|
||||||
|
|
||||||
|
def _create_json_parse_error_messages(
|
||||||
|
last_msg: Message, e: json.JSONDecodeError
|
||||||
|
) -> list[Message]:
|
||||||
|
"""
|
||||||
|
Creates an error message when json parse failed.
|
||||||
|
"""
|
||||||
|
error_msg = (
|
||||||
|
f"Error parsing tool arguments as JSON: {str(e)}. "
|
||||||
|
"Please ensure the tool call arguments are valid JSON and try again."
|
||||||
|
)
|
||||||
|
content = TextContent(text=error_msg)
|
||||||
|
author = Author(role=Role.TOOL, name=last_msg.recipient)
|
||||||
|
return [
|
||||||
|
Message(
|
||||||
|
author=author,
|
||||||
|
content=[content],
|
||||||
|
recipient=Role.ASSISTANT,
|
||||||
|
channel=last_msg.channel,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class SimpleContext(ConversationContext):
|
class SimpleContext(ConversationContext):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.last_output = None
|
self.last_output = None
|
||||||
@ -339,7 +362,13 @@ class HarmonyContext(ConversationContext):
|
|||||||
if isinstance(tool_session, Tool):
|
if isinstance(tool_session, Tool):
|
||||||
return await tool_session.get_result(self)
|
return await tool_session.get_result(self)
|
||||||
tool_name = last_msg.recipient.split(".")[1]
|
tool_name = last_msg.recipient.split(".")[1]
|
||||||
args = json.loads(last_msg.content[0].text)
|
if envs.VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY:
|
||||||
|
try:
|
||||||
|
args = json.loads(last_msg.content[0].text)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
return _create_json_parse_error_messages(last_msg, e)
|
||||||
|
else:
|
||||||
|
args = json.loads(last_msg.content[0].text)
|
||||||
result = await tool_session.call_tool(tool_name, args)
|
result = await tool_session.call_tool(tool_name, args)
|
||||||
result_str = result.content[0].text
|
result_str = result.content[0].text
|
||||||
content = TextContent(text=result_str)
|
content = TextContent(text=result_str)
|
||||||
@ -420,7 +449,13 @@ class HarmonyContext(ConversationContext):
|
|||||||
if isinstance(tool_session, Tool):
|
if isinstance(tool_session, Tool):
|
||||||
return await tool_session.get_result(self)
|
return await tool_session.get_result(self)
|
||||||
tool_name = last_msg.recipient.split(".")[1].split(" ")[0]
|
tool_name = last_msg.recipient.split(".")[1].split(" ")[0]
|
||||||
args = json.loads(last_msg.content[0].text)
|
if envs.VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY:
|
||||||
|
try:
|
||||||
|
args = json.loads(last_msg.content[0].text)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
return _create_json_parse_error_messages(last_msg, e)
|
||||||
|
else:
|
||||||
|
args = json.loads(last_msg.content[0].text)
|
||||||
result = await tool_session.call_tool(tool_name, args)
|
result = await tool_session.call_tool(tool_name, args)
|
||||||
result_str = result.content[0].text
|
result_str = result.content[0].text
|
||||||
content = TextContent(text=result_str)
|
content = TextContent(text=result_str)
|
||||||
|
|||||||
@ -340,7 +340,24 @@ def parse_output_message(message: Message) -> list[ResponseOutputItem]:
|
|||||||
if len(message.content) != 1:
|
if len(message.content) != 1:
|
||||||
raise ValueError("Invalid number of contents in browser message")
|
raise ValueError("Invalid number of contents in browser message")
|
||||||
content = message.content[0]
|
content = message.content[0]
|
||||||
browser_call = json.loads(content.text)
|
# We do not need to check the VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY
|
||||||
|
# env variable since if it is not set, we are certain the json is valid
|
||||||
|
# The use of Actions for web search will be removed entirely in
|
||||||
|
# the future, so this is only necessary temporarily
|
||||||
|
try:
|
||||||
|
browser_call = json.loads(content.text)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# If the content is not valid JSON, then it was
|
||||||
|
# caught and retried by vLLM, which means we
|
||||||
|
# need to make note of that so the user is aware
|
||||||
|
json_retry_output_message = (
|
||||||
|
f"Invalid JSON args, caught and retried: {content.text}"
|
||||||
|
)
|
||||||
|
browser_call = {
|
||||||
|
"query": json_retry_output_message,
|
||||||
|
"url": json_retry_output_message,
|
||||||
|
"pattern": json_retry_output_message,
|
||||||
|
}
|
||||||
# TODO: translate to url properly!
|
# TODO: translate to url properly!
|
||||||
if recipient == "browser.search":
|
if recipient == "browser.search":
|
||||||
action = ActionSearch(
|
action = ActionSearch(
|
||||||
|
|||||||
@ -199,6 +199,7 @@ if TYPE_CHECKING:
|
|||||||
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
|
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
|
||||||
VLLM_TUNED_CONFIG_FOLDER: str | None = None
|
VLLM_TUNED_CONFIG_FOLDER: str | None = None
|
||||||
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
|
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
|
||||||
|
VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY: bool = False
|
||||||
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
|
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
|
||||||
VLLM_NVTX_SCOPES_FOR_PROFILING: bool = False
|
VLLM_NVTX_SCOPES_FOR_PROFILING: bool = False
|
||||||
VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True
|
VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True
|
||||||
@ -1331,6 +1332,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
"VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS": lambda: bool(
|
"VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS": lambda: bool(
|
||||||
int(os.getenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "0"))
|
int(os.getenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "0"))
|
||||||
),
|
),
|
||||||
|
# Enable automatic retry when tool call JSON parsing fails
|
||||||
|
# If enabled, returns an error message to the model to retry
|
||||||
|
# If disabled (default), raises an exception and fails the request
|
||||||
|
"VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY": lambda: bool(
|
||||||
|
int(os.getenv("VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY", "0"))
|
||||||
|
),
|
||||||
# Add optional custom scopes for profiling, disable to avoid overheads
|
# Add optional custom scopes for profiling, disable to avoid overheads
|
||||||
"VLLM_CUSTOM_SCOPES_FOR_PROFILING": lambda: bool(
|
"VLLM_CUSTOM_SCOPES_FOR_PROFILING": lambda: bool(
|
||||||
int(os.getenv("VLLM_CUSTOM_SCOPES_FOR_PROFILING", "0"))
|
int(os.getenv("VLLM_CUSTOM_SCOPES_FOR_PROFILING", "0"))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user