Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-15 11:15:01 +08:00)
[Fix] [gpt-oss] fix non-tool calling path for chat completion (#24324)
commit fb691ee4e7 (parent 6024d115cd)
@@ -36,21 +36,41 @@ def monkeypatch_module():
     mpatch.undo()
 
 
-@pytest.fixture(scope="module")
-def gptoss_server(monkeypatch_module: pytest.MonkeyPatch):
-    with monkeypatch_module.context() as m:
-        m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1")
-        args = [
-            "--enforce-eager",
-            "--max-model-len",
-            "4096",
-            "--tool-call-parser",
-            "openai",
-            "--reasoning-parser",
-            "openai_gptoss",
-            "--enable-auto-tool-choice",
-        ]
-        with RemoteOpenAIServer(GPT_OSS_MODEL_NAME, args) as remote_server:
-            yield remote_server
+@pytest.fixture(scope="module",
+                params=[True, False],
+                ids=["with_tool_parser", "without_tool_parser"])
+def with_tool_parser(request) -> bool:
+    return request.param
+
+
+@pytest.fixture(scope="module")
+def default_server_args(with_tool_parser: bool):
+    args = [
+        # use half precision for speed and memory savings in CI environment
+        "--enforce-eager",
+        "--max-model-len",
+        "8192",
+        "--reasoning-parser",
+        "openai_gptoss",
+        "--gpu-memory-utilization",
+        "0.8",
+    ]
+    if with_tool_parser:
+        args.extend([
+            "--tool-call-parser",
+            "openai",
+            "--enable-auto-tool-choice",
+        ])
+    return args
+
+
+@pytest.fixture(scope="module")
+def gptoss_server(monkeypatch_module: pytest.MonkeyPatch,
+                  default_server_args: list[str]):
+    with monkeypatch_module.context() as m:
+        m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1")
+        with RemoteOpenAIServer(GPT_OSS_MODEL_NAME,
+                                default_server_args) as remote_server:
+            yield remote_server
 
 
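For readers unfamiliar with the pattern introduced above: a module-scoped parametrized fixture makes every dependent fixture (and therefore the whole server startup) run once per parameter. A minimal standalone sketch of the same idea, with hypothetical names that are not part of this change:

import pytest


@pytest.fixture(scope="module",
                params=[True, False],
                ids=["with_tool_parser", "without_tool_parser"])
def with_tool_parser(request) -> bool:
    # Each parameter spawns a separate module-scoped fixture instance,
    # so everything depending on it runs twice.
    return request.param


@pytest.fixture(scope="module")
def server_args(with_tool_parser: bool) -> list[str]:
    # Shared base flags; tool-calling flags are appended only for the
    # "with_tool_parser" run.
    args = ["--enforce-eager"]
    if with_tool_parser:
        args.extend([
            "--tool-call-parser", "openai",
            "--enable-auto-tool-choice",
        ])
    return args


def test_server_args(server_args: list[str], with_tool_parser: bool):
    # Runs once per parameter; the ids show up as
    # test_server_args[with_tool_parser] / [without_tool_parser].
    assert ("--enable-auto-tool-choice" in server_args) == with_tool_parser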
@@ -61,7 +81,8 @@ async def gptoss_client(gptoss_server):
 
 
 @pytest.mark.asyncio
-async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI):
+async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI,
+                                                with_tool_parser: bool):
     tools = [{
         "type": "function",
         "function": {
@@ -94,10 +115,14 @@ async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI):
     ]
 
     stream = await gptoss_client.chat.completions.create(
-        model=GPT_OSS_MODEL_NAME, messages=messages, tools=tools, stream=True)
+        model=GPT_OSS_MODEL_NAME,
+        messages=messages,
+        tools=tools if with_tool_parser else None,
+        stream=True)
 
     name = None
     args_buf = ""
+    content_buf = ""
     async for chunk in stream:
         delta = chunk.choices[0].delta
         if delta.tool_calls:
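The behavioral change here: when the server was started without a tool parser, the test now sends no tools at all and accumulates plain content deltas instead of tool-call deltas. A hedged client-side sketch of that non-tool streaming path (the base_url and model name below are placeholders, not values from this PR):

import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    # Assumes a vLLM OpenAI-compatible server is already running locally.
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    stream = await client.chat.completions.create(
        model="openai/gpt-oss-20b",
        messages=[{"role": "user", "content": "Say hello."}],
        stream=True)  # no tools= argument on the non-tool path
    content_buf = ""
    async for chunk in stream:
        delta = chunk.choices[0].delta
        # Without a tool parser, everything arrives as content deltas;
        # delta.tool_calls stays empty.
        if delta.content:
            content_buf += delta.content
    print(content_buf)


if __name__ == "__main__":
    asyncio.run(main())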
@@ -106,13 +131,22 @@ async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI):
                     name = tc.function.name
                 if tc.function and tc.function.arguments:
                     args_buf += tc.function.arguments
-
-    assert name is not None
-    assert len(args_buf) > 0
+        if getattr(delta, "content", None):
+            content_buf += delta.content
+    if with_tool_parser:
+        assert name is not None
+        assert len(args_buf) > 0
+    else:
+        assert name is None
+        assert len(args_buf) == 0
+        assert len(content_buf) > 0
 
 
 @pytest.mark.asyncio
-async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI):
+async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI,
+                                       with_tool_parser: bool):
+    if not with_tool_parser:
+        pytest.skip("skip non-tool for multi-turn tests")
     tools = [{
         "type": "function",
         "function": {
@@ -175,7 +209,7 @@ async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI):
     )
     second_msg = second.choices[0].message
     assert (second_msg.content is not None and len(second_msg.content) > 0) or \
-        (second_msg.tool_calls is not None and len(second_msg.tool_calls) > 0)  # noqa: E501
+        (second_msg.tool_calls is not None and len(second_msg.tool_calls) > 0)
 
 
 MODEL_NAME = "openai-community/gpt2"
@@ -6,7 +6,7 @@ import json
 import time
 from collections.abc import AsyncGenerator, AsyncIterator
 from collections.abc import Sequence as GenericSequence
-from typing import TYPE_CHECKING, Callable, Final, Optional, Union
+from typing import Callable, Final, Optional, Union
 
 import jinja2
 import partial_json_parser
@@ -1174,6 +1174,7 @@ class OpenAIServingChat(OpenAIServing):
         for output in final_res.outputs:
             token_ids = output.token_ids
             out_logprobs = output.logprobs
+            tool_call_info = None
 
             if request.logprobs and request.top_logprobs is not None:
                 assert out_logprobs is not None, "Did not output logprobs"
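Initializing tool_call_info to None up front is what lets the finish_reason computation further down work on both paths, since the old code dereferenced tool_call_info unconditionally. A minimal sketch of the resulting guard (a hypothetical helper, not vLLM's actual code):

def pick_finish_reason(tool_call_info, engine_finish_reason):
    # tool_call_info is None when no tool parser is configured, so the
    # tool_calls finish reason only applies when a parser ran and
    # actually emitted tool calls.
    if tool_call_info is not None and tool_call_info.tools_called:
        return "tool_calls"
    return engine_finish_reason or "stop"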
@@ -1188,32 +1189,42 @@ class OpenAIServingChat(OpenAIServing):
                 logprobs = None
 
             if self.use_harmony:
-                if TYPE_CHECKING:
-                    assert self.tool_parser is not None
-                tool_parser = self.tool_parser(tokenizer)
-                # NOTE: We use token_ids for openai tool parser
-                tool_call_info = tool_parser.extract_tool_calls(
-                    "",
-                    request=request,
-                    token_ids=token_ids,  # type: ignore
-                )
-                reasoning_content, content = None, tool_call_info.content
-                if request.include_reasoning:
-                    reasoning_content, content, _ = parse_chat_output(
-                        token_ids)
-                message = ChatMessage(
-                    role=role,
-                    reasoning_content=reasoning_content,
-                    content=content,
-                    tool_calls=tool_call_info.tool_calls,
-                )
+                if self.tool_parser is not None:
+                    tool_parser = self.tool_parser(tokenizer)
+                    # NOTE: We use token_ids for openai tool parser
+                    tool_call_info = tool_parser.extract_tool_calls(
+                        "",
+                        request=request,
+                        token_ids=token_ids,  # type: ignore
+                    )
+                    reasoning_content, content = None, tool_call_info.content
+                    if request.include_reasoning:
+                        reasoning_content, content, _ = parse_chat_output(
+                            token_ids)
+                    message = ChatMessage(
+                        role=role,
+                        reasoning_content=reasoning_content,
+                        content=content,
+                        tool_calls=tool_call_info.tool_calls,
+                    )
+                else:
+                    reasoning_content, content, _ = parse_chat_output(
+                        token_ids)
+                    if not request.include_reasoning:
+                        reasoning_content = None
+                    message = ChatMessage(
+                        role=role,
+                        reasoning_content=reasoning_content,
+                        content=content,
+                    )
 
             choice_data = ChatCompletionResponseChoice(
                 index=output.index,
                 message=message,
                 logprobs=logprobs,
-                finish_reason="tool_calls"
-                if tool_call_info.tools_called else
+                finish_reason="tool_calls" if
+                (tool_call_info is not None
+                 and tool_call_info.tools_called) else
                 output.finish_reason if output.finish_reason else "stop",
                 stop_reason=output.stop_reason,
             )
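Distilled, the fixed branch builds the message one of two ways: when a tool parser exists it owns extraction of tool calls and content, otherwise the harmony output is parsed directly and reasoning is dropped unless the request asked for it. A hedged, self-contained restatement of that control flow (stub types and parser stand-ins, not vLLM's actual classes):

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ToolCallInfo:
    content: Optional[str] = None
    tool_calls: list = field(default_factory=list)
    tools_called: bool = False


def parse_chat_output(token_ids):
    # Stand-in for the harmony output parser: returns
    # (reasoning_content, content, _); here a fixed dummy result.
    return "thinking...", "final answer", None


def build_message(token_ids, tool_parser, include_reasoning: bool):
    tool_call_info = None  # defined on every path, as in the fix
    if tool_parser is not None:
        # Tool path: the parser extracts tool calls and content.
        tool_call_info = tool_parser(token_ids)
        reasoning, content = None, tool_call_info.content
        if include_reasoning:
            reasoning, content, _ = parse_chat_output(token_ids)
        message = {"reasoning_content": reasoning, "content": content,
                   "tool_calls": tool_call_info.tool_calls}
    else:
        # Non-tool path (previously crashed): parse the output directly
        # and strip reasoning unless the request asked for it.
        reasoning, content, _ = parse_chat_output(token_ids)
        if not include_reasoning:
            reasoning = None
        message = {"reasoning_content": reasoning, "content": content}
    finish_reason = ("tool_calls" if tool_call_info is not None
                     and tool_call_info.tools_called else "stop")
    return message, finish_reason


# Non-tool path: no parser configured, reasoning suppressed.
msg, fr = build_message([1, 2, 3], tool_parser=None, include_reasoning=False)
assert msg == {"reasoning_content": None, "content": "final answer"}
assert fr == "stop"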