add tools into TokenizeChatRequest (#18187)

Signed-off-by: yangxia <yangxiast@gmail.com>
This commit is contained in:
hustxiayang 2025-05-15 07:01:49 -04:00 committed by GitHub
parent 07ad27121f
commit 451da4bcbd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 84 additions and 0 deletions

View File

@@ -145,6 +145,83 @@ async def test_tokenize_chat(
}
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name,tokenizer_name",
    [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
    indirect=["tokenizer_name"],
)
async def test_tokenize_chat_with_tools(
    server: RemoteOpenAIServer,
    model_name: str,
    tokenizer_name: str,
):
    """The /tokenize endpoint must produce the same token ids as rendering
    the chat template locally when tool definitions are supplied."""
    tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
                              tokenizer_mode="fast")

    # Tool schema is never mutated, so a single literal serves every case.
    tools = [{
        "type": "function",
        "function": {
            "name": "get_weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string"
                    }
                },
            },
        },
    }]

    for add_generation in [False, True]:
        for add_special in [False, True]:
            for continue_final in [False, True]:
                # A generation prompt and a continued final message are
                # mutually exclusive options.
                if add_generation and continue_final:
                    continue

                conversation = [{
                    "role": "user",
                    "content": "What's the weather like in Paris today?",
                }]
                if continue_final:
                    conversation.append({
                        "role": "assistant",
                        "content": "Sure,"
                    })

                # Reference rendering with the exact same options the
                # server is asked to apply.
                prompt = tokenizer.apply_chat_template(
                    add_generation_prompt=add_generation,
                    continue_final_message=continue_final,
                    conversation=conversation,
                    tools=tools,
                    tokenize=False,
                )
                expected = tokenizer.encode(
                    prompt, add_special_tokens=add_special)

                response = requests.post(
                    server.url_for("tokenize"),
                    json={
                        "add_generation_prompt": add_generation,
                        "continue_final_message": continue_final,
                        "add_special_tokens": add_special,
                        "messages": conversation,
                        "model": model_name,
                        "tools": tools,
                    },
                )
                response.raise_for_status()

                assert response.json() == {
                    "tokens": expected,
                    "count": len(expected),
                    "max_model_len": 8192,
                }
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name,tokenizer_name",

View File

@@ -1593,6 +1593,10 @@ class TokenizeChatRequest(OpenAIBaseModel):
default=None,
description=("Additional kwargs to pass to the HF processor."),
)
# Optional tool definitions; when present they are forwarded to the chat
# template so tokenization matches what a chat completion would see.
tools: Optional[list[ChatCompletionToolsParam]] = Field(
    default=None,
    description=("A list of tools the model may call."),
)
@model_validator(mode="before")
@classmethod

View File

@@ -65,6 +65,8 @@ class OpenAIServingTokenization(OpenAIServing):
tokenizer = await self.engine_client.get_tokenizer(lora_request)
if isinstance(request, TokenizeChatRequest):
tool_dicts = (None if request.tools is None else
[tool.model_dump() for tool in request.tools])
(
_,
request_prompts,
@@ -73,6 +75,7 @@ class OpenAIServingTokenization(OpenAIServing):
request,
tokenizer,
request.messages,
tool_dicts=tool_dicts,
chat_template=request.chat_template or self.chat_template,
chat_template_content_format=self.
chat_template_content_format,