[Fix] [gpt-oss] fix non-tool calling path for chat completion (#24324)

This commit is contained in:
Aaron Pham 2025-09-06 15:10:32 -04:00 committed by GitHub
parent 6024d115cd
commit fb691ee4e7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 83 additions and 38 deletions

View File

@ -36,21 +36,41 @@ def monkeypatch_module():
mpatch.undo()
@pytest.fixture(scope="module",
params=[True, False],
ids=["with_tool_parser", "without_tool_parser"])
def with_tool_parser(request) -> bool:
return request.param
@pytest.fixture(scope="module")
def gptoss_server(monkeypatch_module: pytest.MonkeyPatch):
with monkeypatch_module.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1")
args = [
"--enforce-eager",
"--max-model-len",
"8192",
def default_server_args(with_tool_parser: bool):
args = [
# use half precision for speed and memory savings in CI environment
"--enforce-eager",
"--max-model-len",
"4096",
"--reasoning-parser",
"openai_gptoss",
"--gpu-memory-utilization",
"0.8",
]
if with_tool_parser:
args.extend([
"--tool-call-parser",
"openai",
"--reasoning-parser",
"openai_gptoss",
"--enable-auto-tool-choice",
]
with RemoteOpenAIServer(GPT_OSS_MODEL_NAME, args) as remote_server:
])
return args
@pytest.fixture(scope="module")
def gptoss_server(monkeypatch_module: pytest.MonkeyPatch,
default_server_args: list[str]):
with monkeypatch_module.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1")
with RemoteOpenAIServer(GPT_OSS_MODEL_NAME,
default_server_args) as remote_server:
yield remote_server
@ -61,7 +81,8 @@ async def gptoss_client(gptoss_server):
@pytest.mark.asyncio
async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI):
async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI,
with_tool_parser: bool):
tools = [{
"type": "function",
"function": {
@ -94,10 +115,14 @@ async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI):
]
stream = await gptoss_client.chat.completions.create(
model=GPT_OSS_MODEL_NAME, messages=messages, tools=tools, stream=True)
model=GPT_OSS_MODEL_NAME,
messages=messages,
tools=tools if with_tool_parser else None,
stream=True)
name = None
args_buf = ""
content_buf = ""
async for chunk in stream:
delta = chunk.choices[0].delta
if delta.tool_calls:
@ -106,13 +131,22 @@ async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI):
name = tc.function.name
if tc.function and tc.function.arguments:
args_buf += tc.function.arguments
assert name is not None
assert len(args_buf) > 0
if getattr(delta, "content", None):
content_buf += delta.content
if with_tool_parser:
assert name is not None
assert len(args_buf) > 0
else:
assert name is None
assert len(args_buf) == 0
assert len(content_buf) > 0
@pytest.mark.asyncio
async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI):
async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI,
with_tool_parser: bool):
if not with_tool_parser:
pytest.skip("skip non-tool for multi-turn tests")
tools = [{
"type": "function",
"function": {
@ -175,7 +209,7 @@ async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI):
)
second_msg = second.choices[0].message
assert (second_msg.content is not None and len(second_msg.content) > 0) or \
(second_msg.tool_calls is not None and len(second_msg.tool_calls) > 0) # noqa: E501
(second_msg.tool_calls is not None and len(second_msg.tool_calls) > 0)
MODEL_NAME = "openai-community/gpt2"

View File

@ -6,7 +6,7 @@ import json
import time
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
from typing import TYPE_CHECKING, Callable, Final, Optional, Union
from typing import Callable, Final, Optional, Union
import jinja2
import partial_json_parser
@ -1174,6 +1174,7 @@ class OpenAIServingChat(OpenAIServing):
for output in final_res.outputs:
token_ids = output.token_ids
out_logprobs = output.logprobs
tool_call_info = None
if request.logprobs and request.top_logprobs is not None:
assert out_logprobs is not None, "Did not output logprobs"
@ -1188,32 +1189,42 @@ class OpenAIServingChat(OpenAIServing):
logprobs = None
if self.use_harmony:
if TYPE_CHECKING:
assert self.tool_parser is not None
tool_parser = self.tool_parser(tokenizer)
# NOTE: We use token_ids for openai tool parser
tool_call_info = tool_parser.extract_tool_calls(
"",
request=request,
token_ids=token_ids, # type: ignore
)
reasoning_content, content = None, tool_call_info.content
if request.include_reasoning:
if self.tool_parser is not None:
tool_parser = self.tool_parser(tokenizer)
# NOTE: We use token_ids for openai tool parser
tool_call_info = tool_parser.extract_tool_calls(
"",
request=request,
token_ids=token_ids, # type: ignore
)
reasoning_content, content = None, tool_call_info.content
if request.include_reasoning:
reasoning_content, content, _ = parse_chat_output(
token_ids)
message = ChatMessage(
role=role,
reasoning_content=reasoning_content,
content=content,
tool_calls=tool_call_info.tool_calls,
)
else:
reasoning_content, content, _ = parse_chat_output(
token_ids)
message = ChatMessage(
role=role,
reasoning_content=reasoning_content,
content=content,
tool_calls=tool_call_info.tool_calls,
)
if not request.include_reasoning:
reasoning_content = None
message = ChatMessage(
role=role,
reasoning_content=reasoning_content,
content=content,
)
choice_data = ChatCompletionResponseChoice(
index=output.index,
message=message,
logprobs=logprobs,
finish_reason="tool_calls"
if tool_call_info.tools_called else
finish_reason="tool_calls" if
(tool_call_info is not None
and tool_call_info.tools_called) else
output.finish_reason if output.finish_reason else "stop",
stop_reason=output.stop_reason,
)