diff --git a/vllm/entrypoints/harmony_utils.py b/vllm/entrypoints/harmony_utils.py
index 801c82b4fad46..c1b0a084f33f5 100644
--- a/vllm/entrypoints/harmony_utils.py
+++ b/vllm/entrypoints/harmony_utils.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import datetime
+from collections.abc import Iterable
 from typing import Literal, Optional
 
 from openai.types.responses.tool import Tool
@@ -109,3 +110,36 @@ def get_stop_tokens_for_assistant_actions() -> list[int]:
 
 def get_streamable_parser_for_assistant() -> StreamableParser:
     return StreamableParser(get_encoding(), role=Role.ASSISTANT)
+
+
+def parse_output_into_messages(token_ids: Iterable[int]) -> StreamableParser:
+    parser = get_streamable_parser_for_assistant()
+    for token_id in token_ids:
+        parser.process(token_id)
+    return parser
+
+
+def parse_chat_output(
+        token_ids: list[int]) -> tuple[Optional[str], Optional[str], bool]:
+    parser = parse_output_into_messages(token_ids)
+    output_msgs = parser.messages
+    if len(output_msgs) == 0:
+        # The generation has stopped during reasoning.
+        is_tool_call = False
+        reasoning_content = parser.current_content
+        final_content = None
+    elif len(output_msgs) == 1:
+        # The generation has stopped during final message.
+        is_tool_call = False
+        reasoning_content = output_msgs[0].content[0].text
+        final_content = parser.current_content
+    else:
+        if len(output_msgs) != 2:
+            raise ValueError(
+                "Expected 2 output messages (reasoning and final), "
+                f"but got {len(output_msgs)}.")
+        reasoning_msg, final_msg = output_msgs
+        reasoning_content = reasoning_msg.content[0].text
+        final_content = final_msg.content[0].text
+        is_tool_call = final_msg.recipient is not None
+    return reasoning_content, final_content, is_tool_call
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 64f2beb14021a..57aa427207568 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -323,6 +323,7 @@ class ResponsesRequest(OpenAIBaseModel):
         if (top_p := self.top_p) is None:
             top_p = default_sampling_params.get(
                 "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
+        stop_token_ids = default_sampling_params.get("stop_token_ids")
 
         # Structured output
         guided_decoding = None
@@ -340,6 +341,7 @@ class ResponsesRequest(OpenAIBaseModel):
             top_p=top_p,
             max_tokens=max_tokens,
             logprobs=self.top_logprobs,
+            stop_token_ids=stop_token_ids,
             output_kind=(RequestOutputKind.DELTA
                          if self.stream else RequestOutputKind.FINAL_ONLY),
             guided_decoding=guided_decoding,
@@ -404,6 +406,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
         Literal["required"],
         ChatCompletionNamedToolChoiceParam,
     ]] = "none"
+    reasoning_effort: Optional[Literal["low", "medium", "high"]] = None
+    include_reasoning: bool = True
 
     # NOTE this will be ignored by vLLM -- the model determines the behavior
     parallel_tool_calls: Optional[bool] = False
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index e1d8a31672ed3..6ad0a8ec54f7c 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -12,6 +12,7 @@ import jinja2
 import partial_json_parser
 import regex as re
 from fastapi import Request
+from openai_harmony import Message as OpenAIMessage
 from pydantic import TypeAdapter
 
 from vllm.config import ModelConfig
@@ -19,6 +20,10 @@ from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
                                          ConversationMessage,
                                          random_tool_call_id)
+from vllm.entrypoints.harmony_utils import (
+    get_developer_message, get_stop_tokens_for_assistant_actions,
+    get_streamable_parser_for_assistant, get_system_message, parse_chat_input,
+    parse_chat_output, render_for_completion)
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (
     ChatCompletionLogProb, ChatCompletionLogProbs,
@@ -35,6 +40,7 @@ from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
 from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
     MistralToolCall)
 from vllm.entrypoints.utils import get_max_tokens
+from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
 from vllm.logger import init_logger
 from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.reasoning import ReasoningParser, ReasoningParserManager
@@ -125,6 +131,23 @@ class OpenAIServingChat(OpenAIServing):
                 logger.info("Using default chat sampling params from %s: %s",
                             source, self.default_sampling_params)
 
+        self.use_harmony = model_config.hf_config.model_type == "gpt_oss"
+        if self.use_harmony:
+            if "stop_token_ids" not in self.default_sampling_params:
+                self.default_sampling_params["stop_token_ids"] = []
+            self.default_sampling_params["stop_token_ids"].extend(
+                get_stop_tokens_for_assistant_actions())
+
+        # NOTE(woosuk): While OpenAI's chat completion API supports browsing
+        # for some models, currently vLLM doesn't support it. Please use the
+        # Responses API instead.
+        self.supports_browsing = False
+        self.browser_tool = None
+        # NOTE(woosuk): Chat completion API does not support code interpreter.
+        # Please use the Responses API instead.
+        self.supports_code_interpreter = False
+        self.python_tool = None
+
     async def create_chat_completion(
         self,
         request: ChatCompletionRequest,
@@ -169,7 +192,8 @@ class OpenAIServingChat(OpenAIServing):
 
             if (request.tool_choice == "auto" and
                     not (self.enable_auto_tools and tool_parser is not None)
-                    and not isinstance(tokenizer, MistralTokenizer)):
+                    and not isinstance(tokenizer, MistralTokenizer)
+                    and not self.use_harmony):
                 # for hf tokenizers, "auto" tools requires
                 # --enable-auto-tool-choice and --tool-call-parser
                 return self.create_error_response(
@@ -184,25 +208,35 @@ class OpenAIServingChat(OpenAIServing):
             else:
                 tool_dicts = [tool.model_dump() for tool in request.tools]
 
-            (
-                conversation,
-                request_prompts,
-                engine_prompts,
-            ) = await self._preprocess_chat(
-                request,
-                tokenizer,
-                request.messages,
-                chat_template=request.chat_template or self.chat_template,
-                chat_template_content_format=self.chat_template_content_format,
-                add_generation_prompt=request.add_generation_prompt,
-                continue_final_message=request.continue_final_message,
-                tool_dicts=tool_dicts,
-                documents=request.documents,
-                chat_template_kwargs=request.chat_template_kwargs,
-                tool_parser=tool_parser,
-                truncate_prompt_tokens=request.truncate_prompt_tokens,
-                add_special_tokens=request.add_special_tokens,
-            )
+            if not self.use_harmony:
+                # Common case.
+                (
+                    conversation,
+                    request_prompts,
+                    engine_prompts,
+                ) = await self._preprocess_chat(
+                    request,
+                    tokenizer,
+                    request.messages,
+                    chat_template=request.chat_template or self.chat_template,
+                    chat_template_content_format=self.
+                    chat_template_content_format,
+                    add_generation_prompt=request.add_generation_prompt,
+                    continue_final_message=request.continue_final_message,
+                    tool_dicts=tool_dicts,
+                    documents=request.documents,
+                    chat_template_kwargs=request.chat_template_kwargs,
+                    tool_parser=tool_parser,
+                    truncate_prompt_tokens=request.truncate_prompt_tokens,
+                    add_special_tokens=request.add_special_tokens,
+                )
+            else:
+                # For GPT-OSS.
+                (
+                    conversation,
+                    request_prompts,
+                    engine_prompts,
+                ) = self._make_request_with_harmony(request)
         except (ValueError, TypeError, RuntimeError,
                 jinja2.TemplateError) as e:
             logger.exception("Error in preprocessing prompt inputs")
@@ -436,6 +470,11 @@ class OpenAIServingChat(OpenAIServing):
         finish_reason_sent = [False] * num_choices
         num_prompt_tokens = 0
         num_cached_tokens = None
+        if self.use_harmony:
+            harmony_parsers = [
+                get_streamable_parser_for_assistant()
+                for _ in range(num_choices)
+            ]
 
         if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
             tool_choice_function_name = request.tool_choice.function.name
@@ -597,7 +636,18 @@ class OpenAIServingChat(OpenAIServing):
                     else:
                         logprobs = None
 
-                    delta_text = output.text
+                    if self.use_harmony:
+                        harmony_parser = harmony_parsers[i]
+                        for token_id in output.token_ids:
+                            harmony_parser.process(token_id)
+                        # FIXME(woosuk): Support function calling
+                        is_final = harmony_parser.current_channel == "final"
+                        if not (request.include_reasoning or is_final):
+                            # Skip the reasoning content.
+                            continue
+                        delta_text = harmony_parser.last_content_delta or ""
+                    else:
+                        delta_text = output.text
 
                     if not delta_text and not output.token_ids and \
                         not previous_num_tokens[i]:
@@ -607,7 +657,8 @@ class OpenAIServingChat(OpenAIServing):
                     delta_message: Optional[DeltaMessage]
 
                     # just update previous_texts and previous_token_ids
-                    if tool_choice_auto or self.reasoning_parser:
+                    if ((tool_choice_auto or self.reasoning_parser)
+                            and not self.use_harmony):
                         assert previous_texts is not None
                         assert all_previous_token_ids is not None
                         previous_text = previous_texts[i]
@@ -621,8 +672,14 @@ class OpenAIServingChat(OpenAIServing):
                     else:
                         current_token_ids = list(output.token_ids)
 
+                    if self.use_harmony:
+                        if is_final:
+                            delta_message = DeltaMessage(content=delta_text)
+                        else:
+                            delta_message = DeltaMessage(
+                                reasoning_content=delta_text)
                     # handle streaming deltas for tools with named tool_choice
-                    if tool_choice_function_name:
+                    elif tool_choice_function_name:
                         if (self.reasoning_parser and not reasoning_end_arr[i]
                                 and not reasoning_parser.is_reasoning_end(
                                     previous_token_ids)):
@@ -990,7 +1047,38 @@ class OpenAIServingChat(OpenAIServing):
                 )
             else:
                 logprobs = None
-            auto_tools_called = False
+
+            if self.use_harmony:
+                reasoning_content, final_content, is_tool_call = (
+                    parse_chat_output(token_ids))
+                if not request.include_reasoning:
+                    reasoning_content = None
+
+                if is_tool_call:
+                    # TODO(woosuk): Implement tool call for gpt-oss.
+                    # For now, only Responses API supports tool call for
+                    # gpt-oss.
+                    raise NotImplementedError(
+                        "Tool call in Chat Completion API is not supported "
+                        "for gpt-oss yet. Please use Responses API instead.")
+                else:
+                    # Normal message
+                    message = ChatMessage(
+                        role=role,
+                        reasoning_content=reasoning_content,
+                        content=final_content,
+                    )
+
+                choice_data = ChatCompletionResponseChoice(
+                    index=output.index,
+                    message=message,
+                    logprobs=logprobs,
+                    finish_reason="tool_calls" if is_tool_call else
+                    output.finish_reason if output.finish_reason else "stop",
+                    stop_reason=output.stop_reason,
+                )
+                choices.append(choice_data)
+                continue
 
             if self.reasoning_parser:
                 try:
@@ -1003,10 +1091,13 @@ class OpenAIServingChat(OpenAIServing):
                     reasoning_content, content = (
                         reasoning_parser.extract_reasoning_content(
                             output.text, request=request))
+                    if not request.include_reasoning:
+                        reasoning_content = None
                 else:
                     reasoning_content = None
                     content = output.text
 
+            auto_tools_called = False
             # if auto tools are not enabled, and a named tool choice using
             # outlines is not being used
             if (not self.enable_auto_tools or not self.tool_parser) and \
@@ -1261,3 +1352,33 @@ class OpenAIServingChat(OpenAIServing):
             and delta_message.tool_calls[0].function
             and delta_message.tool_calls[0].function.arguments is not None
         )
+
+    def _make_request_with_harmony(
+        self,
+        request: ChatCompletionRequest,
+    ):
+        messages: list[OpenAIMessage] = []
+
+        # Add system message.
+        # NOTE: In Chat Completion API, browsing is enabled by default
+        # if the model supports it. TODO: Support browsing.
+        assert not self.supports_browsing
+        assert not self.supports_code_interpreter
+        sys_msg = get_system_message(
+            reasoning_effort=request.reasoning_effort,
+            browser_description=None,
+            python_description=None)
+        messages.append(sys_msg)
+
+        # Add developer message.
+        dev_msg = get_developer_message()
+        messages.append(dev_msg)
+
+        # Add user message.
+        for chat_msg in request.messages:
+            messages.append(parse_chat_input(chat_msg))
+
+        # Render prompt token ids.
+        prompt_token_ids = render_for_completion(messages)
+        engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids)
+        return messages, [prompt_token_ids], [engine_prompt]
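
Usage note (not part of the patch): a minimal sketch of how the new `parse_chat_output` helper behaves once this change is applied, assuming `openai_harmony` and its Harmony encoding are available; the empty token list below is only a placeholder for real gpt-oss output ids, which in `serving_chat.py` come from `CompletionOutput.token_ids`.

```python
# Hypothetical sketch: split Harmony-encoded assistant tokens into
# reasoning text, final text, and a tool-call flag, mirroring what
# chat_completion_full_generator does in this patch.
from vllm.entrypoints.harmony_utils import parse_chat_output

# Placeholder ids; a real run would pass the generated token ids.
token_ids: list[int] = []

reasoning, final, is_tool_call = parse_chat_output(token_ids)
print("reasoning:", reasoning)        # analysis-channel text (or partial content)
print("final:", final)                # final-channel text, or None if never reached
print("is_tool_call:", is_tool_call)  # True when the final message has a recipient
```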