vllm/vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py
Andrew Xia 421125d03a
[ez] move harmony utils to parser folder (#30117)
Signed-off-by: Andrew Xia <axia@fb.com>
Co-authored-by: Andrew Xia <axia@fb.com>
2025-12-06 17:34:34 -05:00

98 lines
3.4 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Sequence
from typing import TYPE_CHECKING
from vllm.entrypoints.openai.parser.harmony_utils import parse_output_into_messages
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaMessage,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
# Import TokenizerLike only for static type checking; at runtime the name is
# bound to `object` so it still resolves without importing vllm.tokenizers.
if TYPE_CHECKING:
    from vllm.tokenizers import TokenizerLike
else:
    TokenizerLike = object

# Module-level logger for this file, created via init_logger with this
# module's name.
logger = init_logger(__name__)
class OpenAIToolParser(ToolParser):
    """Tool parser that extracts tool calls from token IDs.

    The token IDs are handed to ``parse_output_into_messages`` (from the
    harmony utilities) and the resulting messages are scanned for:

    * tool calls -- messages whose recipient starts with ``functions.``;
    * final content -- messages on the ``final`` channel.

    Text-based extraction and parser-driven streaming are not supported.
    """

    # Recipient prefix that marks a message as a function (tool) call.
    _FUNCTION_RECIPIENT_PREFIX = "functions."

    def __init__(self, tokenizer: "TokenizerLike"):
        """Initialize the parser by delegating to ``ToolParser.__init__``."""
        super().__init__(tokenizer)

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
        token_ids: Sequence[int] | None = None,
    ) -> ExtractedToolCallInformation:
        """Extract tool calls and final content from a complete response.

        Args:
            model_output: Decoded model text. Not used by this parser.
            request: The originating chat completion request. Not used by
                this parser.
            token_ids: Token IDs of the model output. Required.

        Returns:
            An ``ExtractedToolCallInformation`` whose ``tool_calls`` holds
            one ``ToolCall`` per message addressed to ``functions.<name>``
            and whose ``content`` is the text of the last message seen on
            the ``final`` channel (``None`` if there was none).

        Raises:
            NotImplementedError: If ``token_ids`` is ``None``.
        """
        if token_ids is None:
            raise NotImplementedError(
                "OpenAIToolParser requires token IDs and does not support text-based extraction."  # noqa: E501
            )

        parser = parse_output_into_messages(token_ids)
        prefix = self._FUNCTION_RECIPIENT_PREFIX
        tool_calls = []
        final_content = None

        for msg in parser.messages:
            # Messages without any content carry nothing to extract.
            if not msg.content:
                continue
            # Only the first content item of each message is considered.
            msg_text = msg.content[0].text
            recipient = msg.recipient

            if recipient and recipient.startswith(prefix):
                # If no content-type is given assume JSON, as that's the
                # most common case with gpt-oss models.
                if not msg.content_type or "json" in msg.content_type:
                    # load and dump the JSON text to check validity and
                    # remove any extra newlines or other odd formatting
                    try:
                        tool_args = json.dumps(json.loads(msg_text))
                    except json.JSONDecodeError:
                        logger.exception(
                            "Error decoding JSON tool call from response."
                        )
                        # Best effort: pass the raw text through unchanged.
                        tool_args = msg_text
                else:
                    tool_args = msg_text

                tool_calls.append(
                    ToolCall(
                        type="function",
                        function=FunctionCall(
                            # Strip only the leading prefix. Splitting on
                            # every "functions." occurrence would truncate
                            # names that contain that substring themselves.
                            name=recipient.removeprefix(prefix),
                            arguments=tool_args,
                        ),
                    )
                )
            elif msg.channel == "final":
                # A later final-channel message overrides an earlier one.
                final_content = msg_text

        return ExtractedToolCallInformation(
            tools_called=len(tool_calls) > 0,
            tool_calls=tool_calls,
            content=final_content,
        )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Streaming extraction is not implemented by this parser.

        Raises:
            NotImplementedError: Always; per the error message, streaming
                is parsed manually in ``serving_chat.py`` instead.
        """
        raise NotImplementedError(
            "Not being used, manual parsing in serving_chat.py"  # noqa: E501
        )