Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[gpt-oss] Support chat completion api (#22342)
commit f263a4b53f
parent 54991c548a
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import datetime
+from collections.abc import Iterable
 from typing import Literal, Optional
 
 from openai.types.responses.tool import Tool
@@ -109,3 +110,36 @@ def get_stop_tokens_for_assistant_actions() -> list[int]:
 
 def get_streamable_parser_for_assistant() -> StreamableParser:
     return StreamableParser(get_encoding(), role=Role.ASSISTANT)
+
+
+def parse_output_into_messages(token_ids: Iterable[int]) -> StreamableParser:
+    parser = get_streamable_parser_for_assistant()
+    for token_id in token_ids:
+        parser.process(token_id)
+    return parser
+
+
+def parse_chat_output(
+        token_ids: list[int]) -> tuple[Optional[str], Optional[str], bool]:
+    parser = parse_output_into_messages(token_ids)
+    output_msgs = parser.messages
+    if len(output_msgs) == 0:
+        # The generation has stopped during reasoning.
+        is_tool_call = False
+        reasoning_content = parser.current_content
+        final_content = None
+    elif len(output_msgs) == 1:
+        # The generation has stopped during final message.
+        is_tool_call = False
+        reasoning_content = output_msgs[0].content[0].text
+        final_content = parser.current_content
+    else:
+        if len(output_msgs) != 2:
+            raise ValueError(
+                "Expected 2 output messages (reasoning and final), "
+                f"but got {len(output_msgs)}.")
+        reasoning_msg, final_msg = output_msgs
+        reasoning_content = reasoning_msg.content[0].text
+        final_content = final_msg.content[0].text
+        is_tool_call = final_msg.recipient is not None
+    return reasoning_content, final_content, is_tool_call
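For orientation, a minimal usage sketch of the new parse_chat_output helper follows; the empty token list is a placeholder for ids actually produced by the engine, and openai_harmony must be installed for the Harmony encoding to load:

# Hedged sketch: split generated token ids into reasoning text, final text,
# and a tool-call flag, mirroring how serving_chat.py consumes the helper.
from vllm.entrypoints.harmony_utils import parse_chat_output

token_ids: list[int] = []  # placeholder; real ids come from the engine output
reasoning, final, is_tool_call = parse_chat_output(token_ids)
if is_tool_call:
    print("Final message addressed a tool (recipient is set).")
elif final is None:
    print("Generation stopped while still reasoning:", reasoning)
else:
    print("Reasoning:", reasoning)
    print("Answer:", final)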
@@ -323,6 +323,7 @@ class ResponsesRequest(OpenAIBaseModel):
         if (top_p := self.top_p) is None:
             top_p = default_sampling_params.get(
                 "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
+        stop_token_ids = default_sampling_params.get("stop_token_ids")
 
         # Structured output
         guided_decoding = None
@@ -340,6 +341,7 @@ class ResponsesRequest(OpenAIBaseModel):
             top_p=top_p,
             max_tokens=max_tokens,
             logprobs=self.top_logprobs,
+            stop_token_ids=stop_token_ids,
             output_kind=(RequestOutputKind.DELTA
                          if self.stream else RequestOutputKind.FINAL_ONLY),
             guided_decoding=guided_decoding,
@@ -404,6 +406,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
         Literal["required"],
         ChatCompletionNamedToolChoiceParam,
     ]] = "none"
+    reasoning_effort: Optional[Literal["low", "medium", "high"]] = None
+    include_reasoning: bool = True
 
     # NOTE this will be ignored by vLLM -- the model determines the behavior
     parallel_tool_calls: Optional[bool] = False
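A client-side sketch of how the two new request fields could be exercised against a running vLLM server; the base_url, api_key, and model name are assumptions, and the fields are passed via extra_body so the snippet also works with older openai SDKs:

# Hedged sketch: assumes a vLLM server on localhost:8000 serving a gpt-oss model.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="openai/gpt-oss-20b",  # assumed model name
    messages=[{"role": "user", "content": "What is 2 + 2?"}],
    extra_body={
        "reasoning_effort": "low",   # new field: "low" | "medium" | "high"
        "include_reasoning": False,  # new field: drop reasoning_content
    },
)
print(resp.choices[0].message.content)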
@@ -12,6 +12,7 @@ import jinja2
 import partial_json_parser
 import regex as re
 from fastapi import Request
+from openai_harmony import Message as OpenAIMessage
 from pydantic import TypeAdapter
 
 from vllm.config import ModelConfig
@@ -19,6 +20,10 @@ from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
                                          ConversationMessage,
                                          random_tool_call_id)
+from vllm.entrypoints.harmony_utils import (
+    get_developer_message, get_stop_tokens_for_assistant_actions,
+    get_streamable_parser_for_assistant, get_system_message, parse_chat_input,
+    parse_chat_output, render_for_completion)
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (
     ChatCompletionLogProb, ChatCompletionLogProbs,
@@ -35,6 +40,7 @@ from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
 from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
     MistralToolCall)
 from vllm.entrypoints.utils import get_max_tokens
+from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
 from vllm.logger import init_logger
 from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.reasoning import ReasoningParser, ReasoningParserManager
@@ -125,6 +131,23 @@ class OpenAIServingChat(OpenAIServing):
             logger.info("Using default chat sampling params from %s: %s",
                         source, self.default_sampling_params)
 
+        self.use_harmony = model_config.hf_config.model_type == "gpt_oss"
+        if self.use_harmony:
+            if "stop_token_ids" not in self.default_sampling_params:
+                self.default_sampling_params["stop_token_ids"] = []
+            self.default_sampling_params["stop_token_ids"].extend(
+                get_stop_tokens_for_assistant_actions())
+
+        # NOTE(woosuk): While OpenAI's chat completion API supports browsing
+        # for some models, currently vLLM doesn't support it. Please use the
+        # Responses API instead.
+        self.supports_browsing = False
+        self.browser_tool = None
+        # NOTE(woosuk): Chat completion API does not support code interpreter.
+        # Please use the Responses API instead.
+        self.supports_code_interpreter = False
+        self.python_tool = None
+
     async def create_chat_completion(
         self,
         request: ChatCompletionRequest,
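For reference, the stop tokens appended here come from the Harmony encoding; a sketch of the openai_harmony call that get_stop_tokens_for_assistant_actions wraps (assuming the openai_harmony package is installed):

# Hedged sketch: derive the assistant-action stop token ids from openai_harmony.
from openai_harmony import HarmonyEncodingName, load_harmony_encoding

encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
stop_token_ids = encoding.stop_tokens_for_assistant_actions()
print(stop_token_ids)  # ids at which an assistant turn should stop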
@@ -169,7 +192,8 @@ class OpenAIServingChat(OpenAIServing):
 
         if (request.tool_choice == "auto" and
                 not (self.enable_auto_tools and tool_parser is not None)
-                and not isinstance(tokenizer, MistralTokenizer)):
+                and not isinstance(tokenizer, MistralTokenizer)
+                and not self.use_harmony):
             # for hf tokenizers, "auto" tools requires
             # --enable-auto-tool-choice and --tool-call-parser
             return self.create_error_response(
@@ -184,25 +208,35 @@ class OpenAIServingChat(OpenAIServing):
             else:
                 tool_dicts = [tool.model_dump() for tool in request.tools]
 
-            (
-                conversation,
-                request_prompts,
-                engine_prompts,
-            ) = await self._preprocess_chat(
-                request,
-                tokenizer,
-                request.messages,
-                chat_template=request.chat_template or self.chat_template,
-                chat_template_content_format=self.chat_template_content_format,
-                add_generation_prompt=request.add_generation_prompt,
-                continue_final_message=request.continue_final_message,
-                tool_dicts=tool_dicts,
-                documents=request.documents,
-                chat_template_kwargs=request.chat_template_kwargs,
-                tool_parser=tool_parser,
-                truncate_prompt_tokens=request.truncate_prompt_tokens,
-                add_special_tokens=request.add_special_tokens,
-            )
+            if not self.use_harmony:
+                # Common case.
+                (
+                    conversation,
+                    request_prompts,
+                    engine_prompts,
+                ) = await self._preprocess_chat(
+                    request,
+                    tokenizer,
+                    request.messages,
+                    chat_template=request.chat_template or self.chat_template,
+                    chat_template_content_format=self.
+                    chat_template_content_format,
+                    add_generation_prompt=request.add_generation_prompt,
+                    continue_final_message=request.continue_final_message,
+                    tool_dicts=tool_dicts,
+                    documents=request.documents,
+                    chat_template_kwargs=request.chat_template_kwargs,
+                    tool_parser=tool_parser,
+                    truncate_prompt_tokens=request.truncate_prompt_tokens,
+                    add_special_tokens=request.add_special_tokens,
+                )
+            else:
+                # For GPT-OSS.
+                (
+                    conversation,
+                    request_prompts,
+                    engine_prompts,
+                ) = self._make_request_with_harmony(request)
         except (ValueError, TypeError, RuntimeError,
                 jinja2.TemplateError) as e:
             logger.exception("Error in preprocessing prompt inputs")
@@ -436,6 +470,11 @@ class OpenAIServingChat(OpenAIServing):
         finish_reason_sent = [False] * num_choices
         num_prompt_tokens = 0
         num_cached_tokens = None
+        if self.use_harmony:
+            harmony_parsers = [
+                get_streamable_parser_for_assistant()
+                for _ in range(num_choices)
+            ]
 
         if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
             tool_choice_function_name = request.tool_choice.function.name
@@ -597,7 +636,18 @@ class OpenAIServingChat(OpenAIServing):
                     else:
                         logprobs = None
 
-                    delta_text = output.text
+                    if self.use_harmony:
+                        harmony_parser = harmony_parsers[i]
+                        for token_id in output.token_ids:
+                            harmony_parser.process(token_id)
+                        # FIXME(woosuk): Support function calling
+                        is_final = harmony_parser.current_channel == "final"
+                        if not (request.include_reasoning or is_final):
+                            # Skip the reasoning content.
+                            continue
+                        delta_text = harmony_parser.last_content_delta or ""
+                    else:
+                        delta_text = output.text
 
                     if not delta_text and not output.token_ids and \
                         not previous_num_tokens[i]:
@@ -607,7 +657,8 @@ class OpenAIServingChat(OpenAIServing):
                     delta_message: Optional[DeltaMessage]
 
                     # just update previous_texts and previous_token_ids
-                    if tool_choice_auto or self.reasoning_parser:
+                    if ((tool_choice_auto or self.reasoning_parser)
+                            and not self.use_harmony):
                         assert previous_texts is not None
                         assert all_previous_token_ids is not None
                         previous_text = previous_texts[i]
@@ -621,8 +672,14 @@ class OpenAIServingChat(OpenAIServing):
                         else:
                             current_token_ids = list(output.token_ids)
 
+                    if self.use_harmony:
+                        if is_final:
+                            delta_message = DeltaMessage(content=delta_text)
+                        else:
+                            delta_message = DeltaMessage(
+                                reasoning_content=delta_text)
                     # handle streaming deltas for tools with named tool_choice
-                    if tool_choice_function_name:
+                    elif tool_choice_function_name:
                         if (self.reasoning_parser and not reasoning_end_arr[i]
                                 and not reasoning_parser.is_reasoning_end(
                                     previous_token_ids)):
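Taken together, the streaming changes feed each generated token through a per-choice Harmony parser and label the resulting text delta as either content or reasoning_content. A self-contained sketch of that routing; the function below is illustrative and not part of the diff:

# Hedged sketch: stream token ids through a Harmony parser and label each
# text delta as final content or reasoning, mirroring the streaming branch.
from collections.abc import Iterable, Iterator

from vllm.entrypoints.harmony_utils import get_streamable_parser_for_assistant


def split_stream(token_ids: Iterable[int]) -> Iterator[tuple[str, str]]:
    parser = get_streamable_parser_for_assistant()
    for token_id in token_ids:
        parser.process(token_id)
        delta = parser.last_content_delta or ""
        if not delta:
            continue
        kind = "content" if parser.current_channel == "final" else "reasoning"
        yield kind, delta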
@@ -990,7 +1047,38 @@ class OpenAIServingChat(OpenAIServing):
                 )
             else:
                 logprobs = None
-            auto_tools_called = False
 
+            if self.use_harmony:
+                reasoning_content, final_content, is_tool_call = (
+                    parse_chat_output(token_ids))
+                if not request.include_reasoning:
+                    reasoning_content = None
+
+                if is_tool_call:
+                    # TODO(woosuk): Implement tool call for gpt-oss.
+                    # For now, only Responses API supports tool call for
+                    # gpt-oss.
+                    raise NotImplementedError(
+                        "Tool call in Chat Completion API is not supported "
+                        "for gpt-oss yet. Please use Responses API instead.")
+                else:
+                    # Normal message
+                    message = ChatMessage(
+                        role=role,
+                        reasoning_content=reasoning_content,
+                        content=final_content,
+                    )
+
+                choice_data = ChatCompletionResponseChoice(
+                    index=output.index,
+                    message=message,
+                    logprobs=logprobs,
+                    finish_reason="tool_calls" if is_tool_call else
+                    output.finish_reason if output.finish_reason else "stop",
+                    stop_reason=output.stop_reason,
+                )
+                choices.append(choice_data)
+                continue
+
             if self.reasoning_parser:
                 try:
@@ -1003,10 +1091,13 @@ class OpenAIServingChat(OpenAIServing):
                     reasoning_content, content = (
                         reasoning_parser.extract_reasoning_content(
                             output.text, request=request))
+                    if not request.include_reasoning:
+                        reasoning_content = None
                 else:
                     reasoning_content = None
                     content = output.text
 
+            auto_tools_called = False
             # if auto tools are not enabled, and a named tool choice using
             # outlines is not being used
             if (not self.enable_auto_tools or not self.tool_parser) and \
@@ -1261,3 +1352,33 @@ class OpenAIServingChat(OpenAIServing):
             and delta_message.tool_calls[0].function
             and delta_message.tool_calls[0].function.arguments is not None
         )
+
+    def _make_request_with_harmony(
+        self,
+        request: ChatCompletionRequest,
+    ):
+        messages: list[OpenAIMessage] = []
+
+        # Add system message.
+        # NOTE: In Chat Completion API, browsing is enabled by default
+        # if the model supports it. TODO: Support browsing.
+        assert not self.supports_browsing
+        assert not self.supports_code_interpreter
+        sys_msg = get_system_message(
+            reasoning_effort=request.reasoning_effort,
+            browser_description=None,
+            python_description=None)
+        messages.append(sys_msg)
+
+        # Add developer message.
+        dev_msg = get_developer_message()
+        messages.append(dev_msg)
+
+        # Add user message.
+        for chat_msg in request.messages:
+            messages.append(parse_chat_input(chat_msg))
+
+        # Render prompt token ids.
+        prompt_token_ids = render_for_completion(messages)
+        engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids)
+        return messages, [prompt_token_ids], [engine_prompt]
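As a closing illustration, the prompt rendering that _make_request_with_harmony delegates to render_for_completion is roughly a thin wrapper over openai_harmony; a hedged sketch of the equivalent direct calls (the user text is a placeholder, and the wrapper's exact internals may differ):

# Hedged sketch: build a Harmony conversation and render prompt token ids
# directly with openai_harmony, approximating render_for_completion.
from openai_harmony import (Conversation, HarmonyEncodingName, Message, Role,
                            load_harmony_encoding)

encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
convo = Conversation.from_messages(
    [Message.from_role_and_content(Role.USER, "What is 2 + 2?")])
prompt_token_ids = encoding.render_conversation_for_completion(
    convo, Role.ASSISTANT)
print(len(prompt_token_ids))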