mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-10 01:51:52 +08:00
[Bugfix] Fix tool_choice="none" being ignored by GPT-OSS/harmony models (#30867)
Signed-off-by: yujiepu <pyjapple@gmail.com> Signed-off-by: PlatinumGod <pyjapple@gmail.com> Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
This commit is contained in:
parent
45c0526ac9
commit
6a09612b2e
@ -52,8 +52,19 @@ def with_tool_parser(request) -> bool:
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
scope="module",
|
||||
params=[True],
|
||||
ids=["exclude_tools_when_tool_choice_none"],
|
||||
)
|
||||
def exclude_tools_when_tool_choice_none(request) -> bool:
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def default_server_args(with_tool_parser: bool):
|
||||
def default_server_args(
|
||||
with_tool_parser: bool, exclude_tools_when_tool_choice_none: bool
|
||||
):
|
||||
args = [
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--enforce-eager",
|
||||
@ -72,6 +83,8 @@ def default_server_args(with_tool_parser: bool):
|
||||
"--enable-auto-tool-choice",
|
||||
]
|
||||
)
|
||||
if exclude_tools_when_tool_choice_none:
|
||||
args.append("--exclude-tools-when-tool-choice-none")
|
||||
return args
|
||||
|
||||
|
||||
@ -335,6 +348,69 @@ async def test_gpt_oss_tool_message_array_content(
|
||||
assert response_multi_array.choices[0].message is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gpt_oss_tool_choice_none(
|
||||
gptoss_client: OpenAI,
|
||||
with_tool_parser: bool,
|
||||
exclude_tools_when_tool_choice_none: bool,
|
||||
):
|
||||
if not (with_tool_parser and exclude_tools_when_tool_choice_none):
|
||||
pytest.skip(
|
||||
"skip tool_choice tests when non-tool or "
|
||||
"--exclude-tools-when-tool-choice-none not set"
|
||||
)
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string"},
|
||||
"state": {"type": "string"},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
},
|
||||
"required": ["city", "state", "unit"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the temperature(in degrees Celsius) in Dallas?",
|
||||
},
|
||||
]
|
||||
|
||||
tool_choice_auto = await gptoss_client.chat.completions.create(
|
||||
model=GPT_OSS_MODEL_NAME,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
temperature=0.0,
|
||||
)
|
||||
msg = tool_choice_auto.choices[0].message
|
||||
assert len(msg.tool_calls) == 1
|
||||
|
||||
tool_choice_none = await gptoss_client.chat.completions.create(
|
||||
model=GPT_OSS_MODEL_NAME,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="none",
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
msg = tool_choice_none.choices[0].message
|
||||
assert len(msg.tool_calls) == 0
|
||||
|
||||
|
||||
MODEL_NAME = "openai-community/gpt2"
|
||||
MODEL_NAME_SHORT = "gpt2"
|
||||
CHAT_TEMPLATE = "Dummy chat template for testing {}"
|
||||
|
||||
@ -299,7 +299,10 @@ class OpenAIServingChat(OpenAIServing):
|
||||
)
|
||||
else:
|
||||
# For GPT-OSS.
|
||||
conversation, engine_prompts = self._make_request_with_harmony(request)
|
||||
should_include_tools = tool_dicts is not None
|
||||
conversation, engine_prompts = self._make_request_with_harmony(
|
||||
request, should_include_tools
|
||||
)
|
||||
except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(f"{e} {e.__cause__}")
|
||||
@ -1833,6 +1836,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
def _make_request_with_harmony(
|
||||
self,
|
||||
request: ChatCompletionRequest,
|
||||
should_include_tools: bool = True,
|
||||
):
|
||||
messages: list[OpenAIMessage] = []
|
||||
|
||||
@ -1850,12 +1854,14 @@ class OpenAIServingChat(OpenAIServing):
|
||||
reasoning_effort=request.reasoning_effort,
|
||||
browser_description=None,
|
||||
python_description=None,
|
||||
with_custom_tools=request.tools is not None,
|
||||
with_custom_tools=should_include_tools,
|
||||
)
|
||||
messages.append(sys_msg)
|
||||
|
||||
# Add developer message.
|
||||
dev_msg = get_developer_message(tools=request.tools)
|
||||
dev_msg = get_developer_message(
|
||||
tools=request.tools if should_include_tools else None
|
||||
)
|
||||
messages.append(dev_msg)
|
||||
|
||||
# Add user message.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user