mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-22 01:05:59 +08:00
98 lines
3.4 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
import json
|
|
from collections.abc import Sequence
|
|
from typing import TYPE_CHECKING
|
|
|
|
from vllm.entrypoints.openai.parser.harmony_utils import parse_output_into_messages
|
|
from vllm.entrypoints.openai.protocol import (
|
|
ChatCompletionRequest,
|
|
DeltaMessage,
|
|
ExtractedToolCallInformation,
|
|
FunctionCall,
|
|
ToolCall,
|
|
)
|
|
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
|
ToolParser,
|
|
)
|
|
from vllm.logger import init_logger
|
|
|
|
if TYPE_CHECKING:
|
|
from vllm.tokenizers import TokenizerLike
|
|
else:
|
|
TokenizerLike = object
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
|
|
class OpenAIToolParser(ToolParser):
    """Tool parser for gpt-oss models that emit Harmony-formatted output.

    Non-streaming extraction works directly on the generated token IDs
    (via ``parse_output_into_messages``); the text-based path is
    unsupported. Streaming extraction is intentionally not implemented
    here — it is handled manually in ``serving_chat.py``.
    """

    def __init__(self, tokenizer: "TokenizerLike"):
        super().__init__(tokenizer)

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
        token_ids: Sequence[int] | None = None,
    ) -> ExtractedToolCallInformation:
        """Parse Harmony messages out of *token_ids* into tool calls.

        Args:
            model_output: Raw decoded text (unused; parsing is token-based).
            request: The originating chat-completion request (unused here,
                kept for interface compatibility with ``ToolParser``).
            token_ids: Generated token IDs to parse. Required.

        Returns:
            ExtractedToolCallInformation with any tool calls found plus the
            content of the ``final`` channel message, if present.

        Raises:
            NotImplementedError: If ``token_ids`` is not provided.
        """
        if token_ids is None:
            raise NotImplementedError(
                "OpenAIToolParser requires token IDs and does not support text-based extraction."  # noqa: E501
            )

        parser = parse_output_into_messages(token_ids)
        tool_calls: list[ToolCall] = []
        final_content: str | None = None

        for msg in parser.messages:
            # Skip messages with no content parts.
            if not msg.content:
                continue
            msg_text = msg.content[0].text
            if msg.recipient and msg.recipient.startswith("functions."):
                # If no content-type is given assume JSON, as that's the
                # most common case with gpt-oss models.
                if not msg.content_type or "json" in msg.content_type:
                    # load and dump the JSON text to check validity and
                    # remove any extra newlines or other odd formatting
                    try:
                        tool_args = json.dumps(json.loads(msg_text))
                    except json.JSONDecodeError:
                        logger.exception(
                            "Error decoding JSON tool call from response."
                        )
                        tool_args = msg_text
                else:
                    tool_args = msg_text
                tool_calls.append(
                    ToolCall(
                        type="function",
                        function=FunctionCall(
                            # removeprefix (not split) so tool names that
                            # themselves contain "functions." are kept intact.
                            name=msg.recipient.removeprefix("functions."),
                            arguments=tool_args,
                        ),
                    )
                )
            elif msg.channel == "final":
                # Plain assistant answer text; last "final" message wins.
                final_content = msg_text

        return ExtractedToolCallInformation(
            tools_called=len(tool_calls) > 0,
            tool_calls=tool_calls,
            content=final_content,
        )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Unsupported; streaming is parsed manually in serving_chat.py."""
        raise NotImplementedError(
            "Not being used, manual parsing in serving_chat.py"  # noqa: E501
        )