From fb691ee4e776a5fa6780e3752884dc5e0c5ccda1 Mon Sep 17 00:00:00 2001
From: Aaron Pham
Date: Sat, 6 Sep 2025 15:10:32 -0400
Subject: [PATCH] [Fix] [gpt-oss] fix non-tool calling path for chat completion
 (#24324)

---
 tests/entrypoints/openai/test_serving_chat.py | 70 ++++++++++++++-----
 vllm/entrypoints/openai/serving_chat.py       | 51 ++++++++------
 2 files changed, 83 insertions(+), 38 deletions(-)

diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py
index c609cfb5c067..04805dbca74f 100644
--- a/tests/entrypoints/openai/test_serving_chat.py
+++ b/tests/entrypoints/openai/test_serving_chat.py
@@ -36,21 +36,41 @@ def monkeypatch_module():
     mpatch.undo()
 
 
+@pytest.fixture(scope="module",
+                params=[True, False],
+                ids=["with_tool_parser", "without_tool_parser"])
+def with_tool_parser(request) -> bool:
+    return request.param
+
+
 @pytest.fixture(scope="module")
-def gptoss_server(monkeypatch_module: pytest.MonkeyPatch):
-    with monkeypatch_module.context() as m:
-        m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1")
-        args = [
-            "--enforce-eager",
-            "--max-model-len",
-            "8192",
+def default_server_args(with_tool_parser: bool):
+    args = [
+        # use half precision for speed and memory savings in CI environment
+        "--enforce-eager",
+        "--max-model-len",
+        "4096",
+        "--reasoning-parser",
+        "openai_gptoss",
+        "--gpu-memory-utilization",
+        "0.8",
+    ]
+    if with_tool_parser:
+        args.extend([
             "--tool-call-parser",
             "openai",
-            "--reasoning-parser",
-            "openai_gptoss",
             "--enable-auto-tool-choice",
-        ]
-        with RemoteOpenAIServer(GPT_OSS_MODEL_NAME, args) as remote_server:
+        ])
+    return args
+
+
+@pytest.fixture(scope="module")
+def gptoss_server(monkeypatch_module: pytest.MonkeyPatch,
+                  default_server_args: list[str]):
+    with monkeypatch_module.context() as m:
+        m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1")
+        with RemoteOpenAIServer(GPT_OSS_MODEL_NAME,
+                                default_server_args) as remote_server:
             yield remote_server
 
 
@@ -61,7 +81,8 @@ async def gptoss_client(gptoss_server):
 
 
 @pytest.mark.asyncio
-async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI):
+async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI,
+                                                with_tool_parser: bool):
     tools = [{
         "type": "function",
         "function": {
@@ -94,10 +115,14 @@ async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI):
     ]
 
     stream = await gptoss_client.chat.completions.create(
-        model=GPT_OSS_MODEL_NAME, messages=messages, tools=tools, stream=True)
+        model=GPT_OSS_MODEL_NAME,
+        messages=messages,
+        tools=tools if with_tool_parser else None,
+        stream=True)
 
     name = None
     args_buf = ""
+    content_buf = ""
     async for chunk in stream:
         delta = chunk.choices[0].delta
         if delta.tool_calls:
@@ -106,13 +131,22 @@ async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI):
                 name = tc.function.name
             if tc.function and tc.function.arguments:
                 args_buf += tc.function.arguments
-
-    assert name is not None
-    assert len(args_buf) > 0
+        if getattr(delta, "content", None):
+            content_buf += delta.content
+    if with_tool_parser:
+        assert name is not None
+        assert len(args_buf) > 0
+    else:
+        assert name is None
+        assert len(args_buf) == 0
+        assert len(content_buf) > 0
 
 
 @pytest.mark.asyncio
-async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI):
+async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI,
+                                       with_tool_parser: bool):
+    if not with_tool_parser:
+        pytest.skip("skip non-tool for multi-turn tests")
     tools = [{
         "type": "function",
         "function": {
@@ -175,7 +209,7 @@ async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI):
     )
     second_msg = second.choices[0].message
     assert (second_msg.content is not None and len(second_msg.content) > 0) or \
-        (second_msg.tool_calls is not None and len(second_msg.tool_calls) > 0)  # noqa: E501
+        (second_msg.tool_calls is not None and len(second_msg.tool_calls) > 0)
 
 
 MODEL_NAME = "openai-community/gpt2"
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 4cc22787a020..5c7adc53f49b 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -6,7 +6,7 @@ import json
 import time
 from collections.abc import AsyncGenerator, AsyncIterator
 from collections.abc import Sequence as GenericSequence
-from typing import TYPE_CHECKING, Callable, Final, Optional, Union
+from typing import Callable, Final, Optional, Union
 
 import jinja2
 import partial_json_parser
@@ -1174,6 +1174,7 @@ class OpenAIServingChat(OpenAIServing):
         for output in final_res.outputs:
             token_ids = output.token_ids
             out_logprobs = output.logprobs
+            tool_call_info = None
 
             if request.logprobs and request.top_logprobs is not None:
                 assert out_logprobs is not None, "Did not output logprobs"
@@ -1188,32 +1189,42 @@ class OpenAIServingChat(OpenAIServing):
                 logprobs = None
 
             if self.use_harmony:
-                if TYPE_CHECKING:
-                    assert self.tool_parser is not None
-                tool_parser = self.tool_parser(tokenizer)
-                # NOTE: We use token_ids for openai tool parser
-                tool_call_info = tool_parser.extract_tool_calls(
-                    "",
-                    request=request,
-                    token_ids=token_ids,  # type: ignore
-                )
-                reasoning_content, content = None, tool_call_info.content
-                if request.include_reasoning:
+                if self.tool_parser is not None:
+                    tool_parser = self.tool_parser(tokenizer)
+                    # NOTE: We use token_ids for openai tool parser
+                    tool_call_info = tool_parser.extract_tool_calls(
+                        "",
+                        request=request,
+                        token_ids=token_ids,  # type: ignore
+                    )
+                    reasoning_content, content = None, tool_call_info.content
+                    if request.include_reasoning:
+                        reasoning_content, content, _ = parse_chat_output(
+                            token_ids)
+                    message = ChatMessage(
+                        role=role,
+                        reasoning_content=reasoning_content,
+                        content=content,
+                        tool_calls=tool_call_info.tool_calls,
+                    )
+                else:
                     reasoning_content, content, _ = parse_chat_output(
                         token_ids)
-                message = ChatMessage(
-                    role=role,
-                    reasoning_content=reasoning_content,
-                    content=content,
-                    tool_calls=tool_call_info.tool_calls,
-                )
+                    if not request.include_reasoning:
+                        reasoning_content = None
+                    message = ChatMessage(
+                        role=role,
+                        reasoning_content=reasoning_content,
+                        content=content,
+                    )
 
                 choice_data = ChatCompletionResponseChoice(
                     index=output.index,
                     message=message,
                     logprobs=logprobs,
-                    finish_reason="tool_calls"
-                    if tool_call_info.tools_called else
+                    finish_reason="tool_calls" if
+                    (tool_call_info is not None
+                     and tool_call_info.tools_called) else
                     output.finish_reason if output.finish_reason else "stop",
                     stop_reason=output.stop_reason,
                 )
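
Note on the serving_chat.py change above: the patch replaces an unconditional tool-parser instantiation (previously guarded only by a TYPE_CHECKING assert, which does nothing at runtime) with an explicit runtime branch, so the Harmony (gpt-oss) response path also works when the server is launched without --tool-call-parser. Below is a minimal, self-contained sketch of the fixed control flow; ToolCallInfo, build_chat_message, and the toy parse_chat_output are hypothetical stand-ins (the real parse_chat_output decodes Harmony channels from token ids and returns a third value), while the branching and the finish_reason guard mirror the patch:

    from dataclasses import dataclass, field
    from typing import Callable, Optional


    @dataclass
    class ToolCallInfo:
        # hypothetical stand-in for the tool parser's return value
        tools_called: bool = False
        tool_calls: list = field(default_factory=list)
        content: Optional[str] = None


    def parse_chat_output(token_ids: list[int]):
        # toy stand-in: the real helper decodes Harmony channels from token ids
        return "<reasoning>", "<final answer>"


    def build_chat_message(tool_parser: Optional[Callable[[list[int]], ToolCallInfo]],
                           token_ids: list[int],
                           include_reasoning: bool) -> tuple[dict, str]:
        tool_call_info = None  # initialized up front, as the patch now does
        if tool_parser is not None:
            # tool-calling path: the parser splits tool calls from plain content
            tool_call_info = tool_parser(token_ids)
            reasoning, content = None, tool_call_info.content
            if include_reasoning:
                reasoning, content = parse_chat_output(token_ids)
            message = {"reasoning_content": reasoning, "content": content,
                       "tool_calls": tool_call_info.tool_calls}
        else:
            # non-tool path (the previously broken case): parse plain output,
            # then drop reasoning unless the request asked for it
            reasoning, content = parse_chat_output(token_ids)
            if not include_reasoning:
                reasoning = None
            message = {"reasoning_content": reasoning, "content": content}
        # finish_reason must now tolerate tool_call_info being None
        finish_reason = ("tool_calls" if tool_call_info is not None
                         and tool_call_info.tools_called else "stop")
        return message, finish_reason


    # exercising the fixed path: no tool parser configured
    msg, reason = build_chat_message(tool_parser=None, token_ids=[1, 2, 3],
                                     include_reasoning=False)
    assert reason == "stop" and msg["content"] == "<final answer>"
    assert msg["reasoning_content"] is None

The key design point the sketch illustrates: initializing tool_call_info to None before the branch lets a single finish_reason expression serve both paths, instead of assuming a parser result always exists.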