Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 02:55:40 +08:00)
[Fix] [gpt-oss] fix non-tool calling path for chat completion (#24324)
This commit is contained in:
parent 6024d115cd
commit fb691ee4e7
@@ -36,21 +36,41 @@ def monkeypatch_module():
     mpatch.undo()
 
 
-@pytest.fixture(scope="module")
-def gptoss_server(monkeypatch_module: pytest.MonkeyPatch):
-    with monkeypatch_module.context() as m:
-        m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1")
-        args = [
-            "--enforce-eager",
-            "--max-model-len",
-            "8192",
-            "--tool-call-parser",
-            "openai",
-            "--reasoning-parser",
-            "openai_gptoss",
-            "--enable-auto-tool-choice",
-        ]
-        with RemoteOpenAIServer(GPT_OSS_MODEL_NAME, args) as remote_server:
-            yield remote_server
+@pytest.fixture(scope="module",
+                params=[True, False],
+                ids=["with_tool_parser", "without_tool_parser"])
+def with_tool_parser(request) -> bool:
+    return request.param
+
+
+@pytest.fixture(scope="module")
+def default_server_args(with_tool_parser: bool):
+    args = [
+        # use half precision for speed and memory savings in CI environment
+        "--enforce-eager",
+        "--max-model-len",
+        "4096",
+        "--reasoning-parser",
+        "openai_gptoss",
+        "--gpu-memory-utilization",
+        "0.8",
+    ]
+    if with_tool_parser:
+        args.extend([
+            "--tool-call-parser",
+            "openai",
+            "--enable-auto-tool-choice",
+        ])
+    return args
+
+
+@pytest.fixture(scope="module")
+def gptoss_server(monkeypatch_module: pytest.MonkeyPatch,
+                  default_server_args: list[str]):
+    with monkeypatch_module.context() as m:
+        m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1")
+        with RemoteOpenAIServer(GPT_OSS_MODEL_NAME,
+                                default_server_args) as remote_server:
+            yield remote_server
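Aside: the parametrized module-scoped fixture above is what drives both server configurations. A minimal, self-contained sketch of the same pytest pattern (the fixture name server_args and the test name here are illustrative, not part of the commit):

import pytest


@pytest.fixture(scope="module",
                params=[True, False],
                ids=["with_tool_parser", "without_tool_parser"])
def with_tool_parser(request) -> bool:
    # Each param value creates a distinct module-scoped fixture instance,
    # so every dependent fixture and test runs once per value.
    return request.param


@pytest.fixture(scope="module")
def server_args(with_tool_parser: bool) -> list[str]:
    # Dependent fixtures see the current param and can branch on it.
    args = ["--max-model-len", "4096"]
    if with_tool_parser:
        args.extend(["--tool-call-parser", "openai"])
    return args


def test_server_args(server_args: list[str], with_tool_parser: bool):
    # Collected twice, shown as test_server_args[with_tool_parser]
    # and test_server_args[without_tool_parser].
    assert ("--tool-call-parser" in server_args) == with_tool_parser

In the real suite this means gptoss_server is started twice per module: once with the openai tool-call parser enabled and once without it, the latter being exactly the path #24324 fixes.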
@@ -61,7 +81,8 @@ async def gptoss_client(gptoss_server):
 
 
 @pytest.mark.asyncio
-async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI):
+async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI,
+                                                with_tool_parser: bool):
     tools = [{
         "type": "function",
         "function": {
@@ -94,10 +115,14 @@ async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI):
     ]
 
     stream = await gptoss_client.chat.completions.create(
-        model=GPT_OSS_MODEL_NAME, messages=messages, tools=tools, stream=True)
+        model=GPT_OSS_MODEL_NAME,
+        messages=messages,
+        tools=tools if with_tool_parser else None,
+        stream=True)
 
     name = None
     args_buf = ""
+    content_buf = ""
     async for chunk in stream:
         delta = chunk.choices[0].delta
         if delta.tool_calls:
@@ -106,13 +131,22 @@ async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI):
                     name = tc.function.name
                 if tc.function and tc.function.arguments:
                     args_buf += tc.function.arguments
+        if getattr(delta, "content", None):
+            content_buf += delta.content
 
-    assert name is not None
-    assert len(args_buf) > 0
+    if with_tool_parser:
+        assert name is not None
+        assert len(args_buf) > 0
+    else:
+        assert name is None
+        assert len(args_buf) == 0
+        assert len(content_buf) > 0
 
 
 @pytest.mark.asyncio
-async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI):
+async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI,
+                                       with_tool_parser: bool):
+    if not with_tool_parser:
+        pytest.skip("skip non-tool for multi-turn tests")
     tools = [{
         "type": "function",
         "function": {
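For readers following the streaming assertions: the accumulation pattern generalizes to any client that must separate tool-call deltas from plain content deltas. A hedged sketch using the openai async client (the helper name and single-call accumulation are simplifications of my own; production code should bucket argument fragments by tool-call index):

from openai import AsyncOpenAI


async def collect_stream(client: AsyncOpenAI, model: str, messages,
                         tools=None):
    # Returns (tool name, concatenated arguments, concatenated content).
    stream = await client.chat.completions.create(
        model=model, messages=messages, tools=tools, stream=True)
    name, args_buf, content_buf = None, "", ""
    async for chunk in stream:
        delta = chunk.choices[0].delta
        for tc in delta.tool_calls or []:
            # The function name usually arrives once; arguments stream in
            # JSON fragments that only parse after the final chunk.
            if tc.function and tc.function.name:
                name = tc.function.name
            if tc.function and tc.function.arguments:
                args_buf += tc.function.arguments
        if delta.content:
            content_buf += delta.content
    return name, args_buf, content_buf

With tools=None this should come back as (None, "", "<some text>"), which is the non-tool expectation the updated test encodes.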
@@ -175,7 +209,7 @@ async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI):
     )
     second_msg = second.choices[0].message
     assert (second_msg.content is not None and len(second_msg.content) > 0) or \
-        (second_msg.tool_calls is not None and len(second_msg.tool_calls) > 0)  # noqa: E501
+        (second_msg.tool_calls is not None and len(second_msg.tool_calls) > 0)
 
 
 MODEL_NAME = "openai-community/gpt2"
@@ -6,7 +6,7 @@ import json
 import time
 from collections.abc import AsyncGenerator, AsyncIterator
 from collections.abc import Sequence as GenericSequence
-from typing import TYPE_CHECKING, Callable, Final, Optional, Union
+from typing import Callable, Final, Optional, Union
 
 import jinja2
 import partial_json_parser
@@ -1174,6 +1174,7 @@ class OpenAIServingChat(OpenAIServing):
         for output in final_res.outputs:
             token_ids = output.token_ids
             out_logprobs = output.logprobs
+            tool_call_info = None
 
             if request.logprobs and request.top_logprobs is not None:
                 assert out_logprobs is not None, "Did not output logprobs"
@@ -1188,32 +1189,42 @@ class OpenAIServingChat(OpenAIServing):
                 logprobs = None
 
             if self.use_harmony:
-                if TYPE_CHECKING:
-                    assert self.tool_parser is not None
-                tool_parser = self.tool_parser(tokenizer)
-                # NOTE: We use token_ids for openai tool parser
-                tool_call_info = tool_parser.extract_tool_calls(
-                    "",
-                    request=request,
-                    token_ids=token_ids,  # type: ignore
-                )
-                reasoning_content, content = None, tool_call_info.content
-                if request.include_reasoning:
-                    reasoning_content, content, _ = parse_chat_output(
-                        token_ids)
-                message = ChatMessage(
-                    role=role,
-                    reasoning_content=reasoning_content,
-                    content=content,
-                    tool_calls=tool_call_info.tool_calls,
-                )
+                if self.tool_parser is not None:
+                    tool_parser = self.tool_parser(tokenizer)
+                    # NOTE: We use token_ids for openai tool parser
+                    tool_call_info = tool_parser.extract_tool_calls(
+                        "",
+                        request=request,
+                        token_ids=token_ids,  # type: ignore
+                    )
+                    reasoning_content, content = None, tool_call_info.content
+                    if request.include_reasoning:
+                        reasoning_content, content, _ = parse_chat_output(
+                            token_ids)
+                    message = ChatMessage(
+                        role=role,
+                        reasoning_content=reasoning_content,
+                        content=content,
+                        tool_calls=tool_call_info.tool_calls,
+                    )
+                else:
+                    reasoning_content, content, _ = parse_chat_output(
+                        token_ids)
+                    if not request.include_reasoning:
+                        reasoning_content = None
+                    message = ChatMessage(
+                        role=role,
+                        reasoning_content=reasoning_content,
+                        content=content,
+                    )
 
             choice_data = ChatCompletionResponseChoice(
                 index=output.index,
                 message=message,
                 logprobs=logprobs,
-                finish_reason="tool_calls"
-                if tool_call_info.tools_called else
+                finish_reason="tool_calls" if
+                (tool_call_info is not None
+                 and tool_call_info.tools_called) else
                 output.finish_reason if output.finish_reason else "stop",
                 stop_reason=output.stop_reason,
             )
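Taken together, the serving changes make the Harmony path tolerate a missing tool parser: tool_call_info starts as None, the parser branch only runs when self.tool_parser is configured, the else branch recovers reasoning and content via parse_chat_output alone, and the finish_reason expression guards against tool_call_info being None. A sketch of the client-side scenario this unblocks, against a gpt-oss server launched without --tool-call-parser (base_url, api_key, and model name are placeholders for your deployment):

import asyncio

from openai import AsyncOpenAI


async def main():
    client = AsyncOpenAI(base_url="http://localhost:8000/v1",
                         api_key="EMPTY")
    resp = await client.chat.completions.create(
        model="gpt-oss",  # placeholder; use your served model name
        messages=[{"role": "user", "content": "What is 2 + 2?"}],
    )
    msg = resp.choices[0].message
    # Before this commit the Harmony path invoked the tool parser
    # unconditionally, which breaks when none is configured; now a plain
    # completion returns content (plus reasoning_content if requested)
    # with finish_reason "stop" rather than "tool_calls".
    assert msg.content
    print(msg.content)


asyncio.run(main())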