[Fix] [gpt-oss] fix non-tool calling path for chat completion (#24324)

This commit is contained in:
Aaron Pham 2025-09-06 15:10:32 -04:00 committed by GitHub
parent 6024d115cd
commit fb691ee4e7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 83 additions and 38 deletions

View File

@ -36,21 +36,41 @@ def monkeypatch_module():
mpatch.undo() mpatch.undo()
@pytest.fixture(scope="module",
params=[True, False],
ids=["with_tool_parser", "without_tool_parser"])
def with_tool_parser(request) -> bool:
return request.param
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def gptoss_server(monkeypatch_module: pytest.MonkeyPatch): def default_server_args(with_tool_parser: bool):
with monkeypatch_module.context() as m: args = [
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1") # use half precision for speed and memory savings in CI environment
args = [ "--enforce-eager",
"--enforce-eager", "--max-model-len",
"--max-model-len", "4096",
"8192", "--reasoning-parser",
"openai_gptoss",
"--gpu-memory-utilization",
"0.8",
]
if with_tool_parser:
args.extend([
"--tool-call-parser", "--tool-call-parser",
"openai", "openai",
"--reasoning-parser",
"openai_gptoss",
"--enable-auto-tool-choice", "--enable-auto-tool-choice",
] ])
with RemoteOpenAIServer(GPT_OSS_MODEL_NAME, args) as remote_server: return args
@pytest.fixture(scope="module")
def gptoss_server(monkeypatch_module: pytest.MonkeyPatch,
default_server_args: list[str]):
with monkeypatch_module.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1")
with RemoteOpenAIServer(GPT_OSS_MODEL_NAME,
default_server_args) as remote_server:
yield remote_server yield remote_server
@ -61,7 +81,8 @@ async def gptoss_client(gptoss_server):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI): async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI,
with_tool_parser: bool):
tools = [{ tools = [{
"type": "function", "type": "function",
"function": { "function": {
@ -94,10 +115,14 @@ async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI):
] ]
stream = await gptoss_client.chat.completions.create( stream = await gptoss_client.chat.completions.create(
model=GPT_OSS_MODEL_NAME, messages=messages, tools=tools, stream=True) model=GPT_OSS_MODEL_NAME,
messages=messages,
tools=tools if with_tool_parser else None,
stream=True)
name = None name = None
args_buf = "" args_buf = ""
content_buf = ""
async for chunk in stream: async for chunk in stream:
delta = chunk.choices[0].delta delta = chunk.choices[0].delta
if delta.tool_calls: if delta.tool_calls:
@ -106,13 +131,22 @@ async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI):
name = tc.function.name name = tc.function.name
if tc.function and tc.function.arguments: if tc.function and tc.function.arguments:
args_buf += tc.function.arguments args_buf += tc.function.arguments
if getattr(delta, "content", None):
assert name is not None content_buf += delta.content
assert len(args_buf) > 0 if with_tool_parser:
assert name is not None
assert len(args_buf) > 0
else:
assert name is None
assert len(args_buf) == 0
assert len(content_buf) > 0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI): async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI,
with_tool_parser: bool):
if not with_tool_parser:
pytest.skip("skip non-tool for multi-turn tests")
tools = [{ tools = [{
"type": "function", "type": "function",
"function": { "function": {
@ -175,7 +209,7 @@ async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI):
) )
second_msg = second.choices[0].message second_msg = second.choices[0].message
assert (second_msg.content is not None and len(second_msg.content) > 0) or \ assert (second_msg.content is not None and len(second_msg.content) > 0) or \
(second_msg.tool_calls is not None and len(second_msg.tool_calls) > 0) # noqa: E501 (second_msg.tool_calls is not None and len(second_msg.tool_calls) > 0)
MODEL_NAME = "openai-community/gpt2" MODEL_NAME = "openai-community/gpt2"

View File

@ -6,7 +6,7 @@ import json
import time import time
from collections.abc import AsyncGenerator, AsyncIterator from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence from collections.abc import Sequence as GenericSequence
from typing import TYPE_CHECKING, Callable, Final, Optional, Union from typing import Callable, Final, Optional, Union
import jinja2 import jinja2
import partial_json_parser import partial_json_parser
@ -1174,6 +1174,7 @@ class OpenAIServingChat(OpenAIServing):
for output in final_res.outputs: for output in final_res.outputs:
token_ids = output.token_ids token_ids = output.token_ids
out_logprobs = output.logprobs out_logprobs = output.logprobs
tool_call_info = None
if request.logprobs and request.top_logprobs is not None: if request.logprobs and request.top_logprobs is not None:
assert out_logprobs is not None, "Did not output logprobs" assert out_logprobs is not None, "Did not output logprobs"
@ -1188,32 +1189,42 @@ class OpenAIServingChat(OpenAIServing):
logprobs = None logprobs = None
if self.use_harmony: if self.use_harmony:
if TYPE_CHECKING: if self.tool_parser is not None:
assert self.tool_parser is not None tool_parser = self.tool_parser(tokenizer)
tool_parser = self.tool_parser(tokenizer) # NOTE: We use token_ids for openai tool parser
# NOTE: We use token_ids for openai tool parser tool_call_info = tool_parser.extract_tool_calls(
tool_call_info = tool_parser.extract_tool_calls( "",
"", request=request,
request=request, token_ids=token_ids, # type: ignore
token_ids=token_ids, # type: ignore )
) reasoning_content, content = None, tool_call_info.content
reasoning_content, content = None, tool_call_info.content if request.include_reasoning:
if request.include_reasoning: reasoning_content, content, _ = parse_chat_output(
token_ids)
message = ChatMessage(
role=role,
reasoning_content=reasoning_content,
content=content,
tool_calls=tool_call_info.tool_calls,
)
else:
reasoning_content, content, _ = parse_chat_output( reasoning_content, content, _ = parse_chat_output(
token_ids) token_ids)
message = ChatMessage( if not request.include_reasoning:
role=role, reasoning_content = None
reasoning_content=reasoning_content, message = ChatMessage(
content=content, role=role,
tool_calls=tool_call_info.tool_calls, reasoning_content=reasoning_content,
) content=content,
)
choice_data = ChatCompletionResponseChoice( choice_data = ChatCompletionResponseChoice(
index=output.index, index=output.index,
message=message, message=message,
logprobs=logprobs, logprobs=logprobs,
finish_reason="tool_calls" finish_reason="tool_calls" if
if tool_call_info.tools_called else (tool_call_info is not None
and tool_call_info.tools_called) else
output.finish_reason if output.finish_reason else "stop", output.finish_reason if output.finish_reason else "stop",
stop_reason=output.stop_reason, stop_reason=output.stop_reason,
) )