[Frontend] [gpt-oss] Tool json call parsing error retry (#27675)

Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
Signed-off-by: Alec Solder <alecs@fb.com>
Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
Co-authored-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
Co-authored-by: Alec Solder <alecs@fb.com>
Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com>
This commit is contained in:
Alec S 2025-10-29 05:42:44 -04:00 committed by GitHub
parent 1891cf605a
commit 3c7fefdeba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 62 additions and 3 deletions

View File

@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, Union
from openai.types.responses.tool import Mcp from openai.types.responses.tool import Mcp
from openai_harmony import Author, Message, Role, StreamState, TextContent from openai_harmony import Author, Message, Role, StreamState, TextContent
from vllm import envs
from vllm.entrypoints.harmony_utils import ( from vllm.entrypoints.harmony_utils import (
get_encoding, get_encoding,
get_streamable_parser_for_assistant, get_streamable_parser_for_assistant,
@ -109,6 +110,28 @@ class ConversationContext(ABC):
raise NotImplementedError("Should not be called.") raise NotImplementedError("Should not be called.")
def _create_json_parse_error_messages(
last_msg: Message, e: json.JSONDecodeError
) -> list[Message]:
"""
Creates an error message when json parse failed.
"""
error_msg = (
f"Error parsing tool arguments as JSON: {str(e)}. "
"Please ensure the tool call arguments are valid JSON and try again."
)
content = TextContent(text=error_msg)
author = Author(role=Role.TOOL, name=last_msg.recipient)
return [
Message(
author=author,
content=[content],
recipient=Role.ASSISTANT,
channel=last_msg.channel,
)
]
class SimpleContext(ConversationContext): class SimpleContext(ConversationContext):
def __init__(self): def __init__(self):
self.last_output = None self.last_output = None
@ -339,7 +362,13 @@ class HarmonyContext(ConversationContext):
if isinstance(tool_session, Tool): if isinstance(tool_session, Tool):
return await tool_session.get_result(self) return await tool_session.get_result(self)
tool_name = last_msg.recipient.split(".")[1] tool_name = last_msg.recipient.split(".")[1]
args = json.loads(last_msg.content[0].text) if envs.VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY:
try:
args = json.loads(last_msg.content[0].text)
except json.JSONDecodeError as e:
return _create_json_parse_error_messages(last_msg, e)
else:
args = json.loads(last_msg.content[0].text)
result = await tool_session.call_tool(tool_name, args) result = await tool_session.call_tool(tool_name, args)
result_str = result.content[0].text result_str = result.content[0].text
content = TextContent(text=result_str) content = TextContent(text=result_str)
@ -420,7 +449,13 @@ class HarmonyContext(ConversationContext):
if isinstance(tool_session, Tool): if isinstance(tool_session, Tool):
return await tool_session.get_result(self) return await tool_session.get_result(self)
tool_name = last_msg.recipient.split(".")[1].split(" ")[0] tool_name = last_msg.recipient.split(".")[1].split(" ")[0]
args = json.loads(last_msg.content[0].text) if envs.VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY:
try:
args = json.loads(last_msg.content[0].text)
except json.JSONDecodeError as e:
return _create_json_parse_error_messages(last_msg, e)
else:
args = json.loads(last_msg.content[0].text)
result = await tool_session.call_tool(tool_name, args) result = await tool_session.call_tool(tool_name, args)
result_str = result.content[0].text result_str = result.content[0].text
content = TextContent(text=result_str) content = TextContent(text=result_str)

View File

@ -340,7 +340,24 @@ def parse_output_message(message: Message) -> list[ResponseOutputItem]:
if len(message.content) != 1: if len(message.content) != 1:
raise ValueError("Invalid number of contents in browser message") raise ValueError("Invalid number of contents in browser message")
content = message.content[0] content = message.content[0]
browser_call = json.loads(content.text) # We do not need to check the VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY
# env variable since if it is not set, we are certain the json is valid
# The use of Actions for web search will be removed entirely in
# the future, so this is only necessary temporarily
try:
browser_call = json.loads(content.text)
except json.JSONDecodeError:
# If the content is not valid JSON, then it was
# caught and retried by vLLM, which means we
# need to make note of that so the user is aware
json_retry_output_message = (
f"Invalid JSON args, caught and retried: {content.text}"
)
browser_call = {
"query": json_retry_output_message,
"url": json_retry_output_message,
"pattern": json_retry_output_message,
}
# TODO: translate to url properly! # TODO: translate to url properly!
if recipient == "browser.search": if recipient == "browser.search":
action = ActionSearch( action = ActionSearch(

View File

@ -199,6 +199,7 @@ if TYPE_CHECKING:
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
VLLM_TUNED_CONFIG_FOLDER: str | None = None VLLM_TUNED_CONFIG_FOLDER: str | None = None
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY: bool = False
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
VLLM_NVTX_SCOPES_FOR_PROFILING: bool = False VLLM_NVTX_SCOPES_FOR_PROFILING: bool = False
VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True
@ -1331,6 +1332,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS": lambda: bool( "VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS": lambda: bool(
int(os.getenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "0")) int(os.getenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "0"))
), ),
# Enable automatic retry when tool call JSON parsing fails
# If enabled, returns an error message to the model to retry
# If disabled (default), raises an exception and fails the request
"VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY": lambda: bool(
int(os.getenv("VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY", "0"))
),
# Add optional custom scopes for profiling, disable to avoid overheads # Add optional custom scopes for profiling, disable to avoid overheads
"VLLM_CUSTOM_SCOPES_FOR_PROFILING": lambda: bool( "VLLM_CUSTOM_SCOPES_FOR_PROFILING": lambda: bool(
int(os.getenv("VLLM_CUSTOM_SCOPES_FOR_PROFILING", "0")) int(os.getenv("VLLM_CUSTOM_SCOPES_FOR_PROFILING", "0"))