diff --git a/tests/tokenization/test_mistral_tokenizer.py b/tests/tokenization/test_mistral_tokenizer.py
new file mode 100644
index 000000000000..03e1f1fadd73
--- /dev/null
+++ b/tests/tokenization/test_mistral_tokenizer.py
@@ -0,0 +1,50 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+from mistral_common.protocol.instruct.messages import UserMessage
+from mistral_common.protocol.instruct.request import ChatCompletionRequest
+from mistral_common.protocol.instruct.tool_calls import Function, Tool
+
+from vllm.transformers_utils.tokenizers.mistral import (
+    make_mistral_chat_completion_request)
+
+
+# yapf: enable
+@pytest.mark.parametrize(
+    "openai_request,expected_mistral_request",
+    [(
+        {
+            "messages": [{
+                "role": "user",
+                "content": "What is the current local date and time?",
+            }],
+            "tools": [{
+                "type": "function",
+                "function": {
+                    "description": "Fetch the current local date and time.",
+                    "name": "get_current_time",
+                },
+            }],
+        },
+        ChatCompletionRequest(
+            messages=[
+                UserMessage(content="What is the current local date and time?")
+            ],
+            tools=[
+                Tool(
+                    type="function",
+                    function=Function(
+                        name="get_current_time",
+                        description="Fetch the current local date and time.",
+                        parameters={},
+                    ),
+                )
+            ],
+        ),
+    )],
+)
+def test_make_mistral_chat_completion_request(openai_request,
+                                              expected_mistral_request):
+    assert (make_mistral_chat_completion_request(
+        openai_request["messages"],
+        openai_request["tools"]) == expected_mistral_request)
diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py
index 8d96fcd278e6..f08923e7401f
--- a/vllm/transformers_utils/tokenizers/mistral.py
+++ b/vllm/transformers_utils/tokenizers/mistral.py
@@ -104,6 +104,38 @@ def find_tokenizer_file(files: List[str]):
     return matched_files[0]
 
 
+def make_mistral_chat_completion_request(
+        messages: List["ChatCompletionMessageParam"],
+        tools: Optional[List[Dict[str,
+                                  Any]]] = None) -> "ChatCompletionRequest":
+    last_message = cast(Dict[str, Any], messages[-1])
+    if last_message["role"] == "assistant":
+        last_message["prefix"] = True
+
+    # mistral-common requires AssistantMessage content to be a string [1].
+    #
+    # [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
+    for message in messages:
+        if message.get("role") == "assistant":
+            content = message.get("content")
+            if isinstance(content, list):
+                content = "\n".join(chunk.get("text") for chunk in content)
+            message["content"] = content
+
+    # mistral-common, in contrast to the OpenAI client, requires the
+    # "parameters" dict to be present, even if it's empty.
+    if tools:
+        for function in [
+                tool["function"] for tool in tools
+                if tool["type"] == "function"
+        ]:
+            function.setdefault("parameters", {})
+
+    from mistral_common.protocol.instruct.request import ChatCompletionRequest
+    return ChatCompletionRequest(messages=messages,
+                                 tools=tools)  # type: ignore[type-var]
+
+
 class MistralTokenizer:
 
     def __init__(self, tokenizer: "PublicMistralTokenizer") -> None:
@@ -283,27 +315,10 @@ class MistralTokenizer:
 
     def apply_chat_template(self,
                             messages: List["ChatCompletionMessageParam"],
-                            tools: Optional[Dict[str, Any]] = None,
+                            tools: Optional[List[Dict[str, Any]]] = None,
                             **kwargs) -> List[int]:
-        last_message = cast(Dict[str, Any], messages[-1])
-        if last_message["role"] == "assistant":
-            last_message["prefix"] = True
-
-        from mistral_common.protocol.instruct.request import (
-            ChatCompletionRequest)
-
-        # mistral-common requires AssistantMessage content to be a string [1].
-        #
-        # [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
-        for message in messages:
-            if message.get("role") == "assistant":
-                content = message.get("content")
-                if isinstance(content, list):
-                    content = "\n".join(chunk.get("text") for chunk in content)
-                message["content"] = content
-        request = ChatCompletionRequest(messages=messages,
-                                        tools=tools)  # type: ignore[type-var]
+        request = make_mistral_chat_completion_request(messages, tools)
 
         encoded = self.mistral.encode_chat_completion(request)
 
         # encode-decode to get clean prompt
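As a usage illustration only (not part of the diff), the sketch below exercises the extracted helper, assuming `vllm` and `mistral-common` are installed; the message and tool payloads are made up. It shows the three normalizations the helper applies before building the `ChatCompletionRequest`: a trailing assistant message is flagged with `prefix=True`, list-form assistant content is flattened to a newline-joined string, and a function tool missing `"parameters"` gets an empty dict.

```python
# Hypothetical OpenAI-style payloads; each one exercises a normalization path.
from vllm.transformers_utils.tokenizers.mistral import (
    make_mistral_chat_completion_request)

messages = [
    {"role": "user", "content": "What time is it?"},
    # Trailing assistant message: the helper marks it prefix=True so the
    # model continues this partial reply instead of starting a new turn.
    # Its list-form content is joined into a single string.
    {"role": "assistant", "content": [{"type": "text", "text": "It is"}]},
]
tools = [{
    "type": "function",
    "function": {
        "name": "get_current_time",
        "description": "Fetch the current local date and time.",
        # "parameters" intentionally omitted; the helper defaults it to {}.
    },
}]

request = make_mistral_chat_completion_request(messages, tools)

assert request.messages[-1].prefix is True
assert request.messages[-1].content == "It is"
assert request.tools[0].function.parameters == {}
```

The `prefix=True` flag is what lets mistral-common encode the trailing assistant message as an open continuation rather than a completed turn, which is why the helper only sets it on the last message.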