mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-22 15:11:19 +08:00
[Frontend] Clean up type annotations for mistral tokenizer (#8314)
This commit is contained in:
parent
6234385f4a
commit
8c054b7a62
@ -1,6 +1,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from vllm.entrypoints.chat_utils import apply_chat_template, load_chat_template
|
from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
|
||||||
|
load_chat_template)
|
||||||
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
|
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
|
||||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||||
|
|
||||||
@ -87,7 +88,7 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
|
|||||||
add_generation_prompt=add_generation_prompt)
|
add_generation_prompt=add_generation_prompt)
|
||||||
|
|
||||||
# Call the function and get the result
|
# Call the function and get the result
|
||||||
result = apply_chat_template(
|
result = apply_hf_chat_template(
|
||||||
tokenizer,
|
tokenizer,
|
||||||
conversation=mock_request.messages,
|
conversation=mock_request.messages,
|
||||||
chat_template=mock_request.chat_template or template_content,
|
chat_template=mock_request.chat_template or template_content,
|
||||||
|
|||||||
@ -23,6 +23,7 @@ from openai.types.chat import (ChatCompletionMessageToolCallParam,
|
|||||||
# yapf: enable
|
# yapf: enable
|
||||||
# pydantic needs the TypedDict from typing_extensions
|
# pydantic needs the TypedDict from typing_extensions
|
||||||
from pydantic import ConfigDict
|
from pydantic import ConfigDict
|
||||||
|
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||||
from typing_extensions import Required, TypeAlias, TypedDict
|
from typing_extensions import Required, TypeAlias, TypedDict
|
||||||
|
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
@ -31,7 +32,7 @@ from vllm.multimodal import MultiModalDataDict
|
|||||||
from vllm.multimodal.utils import (async_get_and_parse_audio,
|
from vllm.multimodal.utils import (async_get_and_parse_audio,
|
||||||
async_get_and_parse_image,
|
async_get_and_parse_image,
|
||||||
get_and_parse_audio, get_and_parse_image)
|
get_and_parse_audio, get_and_parse_image)
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -379,6 +380,9 @@ def _parse_chat_message_content_parts(
|
|||||||
audio_url = _AudioParser(part)["audio_url"]
|
audio_url = _AudioParser(part)["audio_url"]
|
||||||
|
|
||||||
mm_parser.parse_audio(audio_url["url"])
|
mm_parser.parse_audio(audio_url["url"])
|
||||||
|
elif part_type == "refusal":
|
||||||
|
text = _RefusalParser(part)["refusal"]
|
||||||
|
texts.append(text)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Unknown part type: {part_type}")
|
raise NotImplementedError(f"Unknown part type: {part_type}")
|
||||||
|
|
||||||
@ -433,6 +437,21 @@ def _parse_chat_message_content(
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _postprocess_messages(messages: List[ConversationMessage]) -> None:
|
||||||
|
# per the Transformers docs & maintainers, tool call arguments in
|
||||||
|
# assistant-role messages with tool_calls need to be dicts not JSON str -
|
||||||
|
# this is how tool-use chat templates will expect them moving forwards
|
||||||
|
# so, for messages that have tool_calls, parse the string (which we get
|
||||||
|
# from openAI format) to dict
|
||||||
|
for message in messages:
|
||||||
|
if (message["role"] == "assistant" and "tool_calls" in message
|
||||||
|
and isinstance(message["tool_calls"], list)):
|
||||||
|
|
||||||
|
for item in message["tool_calls"]:
|
||||||
|
item["function"]["arguments"] = json.loads(
|
||||||
|
item["function"]["arguments"])
|
||||||
|
|
||||||
|
|
||||||
def parse_chat_messages(
|
def parse_chat_messages(
|
||||||
messages: List[ChatCompletionMessageParam],
|
messages: List[ChatCompletionMessageParam],
|
||||||
model_config: ModelConfig,
|
model_config: ModelConfig,
|
||||||
@ -446,6 +465,8 @@ def parse_chat_messages(
|
|||||||
|
|
||||||
conversation.extend(sub_messages)
|
conversation.extend(sub_messages)
|
||||||
|
|
||||||
|
_postprocess_messages(conversation)
|
||||||
|
|
||||||
return conversation, mm_tracker.all_mm_data()
|
return conversation, mm_tracker.all_mm_data()
|
||||||
|
|
||||||
|
|
||||||
@ -462,41 +483,41 @@ def parse_chat_messages_futures(
|
|||||||
|
|
||||||
conversation.extend(sub_messages)
|
conversation.extend(sub_messages)
|
||||||
|
|
||||||
|
_postprocess_messages(conversation)
|
||||||
|
|
||||||
return conversation, mm_tracker.all_mm_data()
|
return conversation, mm_tracker.all_mm_data()
|
||||||
|
|
||||||
|
|
||||||
def apply_chat_template(
|
def apply_hf_chat_template(
|
||||||
tokenizer: AnyTokenizer,
|
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||||
conversation: List[ConversationMessage],
|
conversation: List[ConversationMessage],
|
||||||
chat_template: Optional[str],
|
chat_template: Optional[str],
|
||||||
*,
|
*,
|
||||||
tokenize: bool = False, # Different from HF's default
|
tokenize: bool = False, # Different from HF's default
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Union[str, List[int]]:
|
) -> str:
|
||||||
if chat_template is None and tokenizer.chat_template is None:
|
if chat_template is None and tokenizer.chat_template is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"As of transformers v4.44, default chat template is no longer "
|
"As of transformers v4.44, default chat template is no longer "
|
||||||
"allowed, so you must provide a chat template if the tokenizer "
|
"allowed, so you must provide a chat template if the tokenizer "
|
||||||
"does not define one.")
|
"does not define one.")
|
||||||
|
|
||||||
# per the Transformers docs & maintainers, tool call arguments in
|
return tokenizer.apply_chat_template(
|
||||||
# assistant-role messages with tool_calls need to be dicts not JSON str -
|
conversation=conversation, # type: ignore[arg-type]
|
||||||
# this is how tool-use chat templates will expect them moving forwards
|
|
||||||
# so, for messages that have tool_calls, parse the string (which we get
|
|
||||||
# from openAI format) to dict
|
|
||||||
for message in conversation:
|
|
||||||
if (message["role"] == "assistant" and "tool_calls" in message
|
|
||||||
and isinstance(message["tool_calls"], list)):
|
|
||||||
|
|
||||||
for i in range(len(message["tool_calls"])):
|
|
||||||
args: str = message["tool_calls"][i]["function"]["arguments"]
|
|
||||||
parsed_args: Dict = json.loads(args)
|
|
||||||
message["tool_calls"][i]["function"]["arguments"] = parsed_args
|
|
||||||
|
|
||||||
prompt = tokenizer.apply_chat_template(
|
|
||||||
conversation=conversation,
|
|
||||||
chat_template=chat_template,
|
chat_template=chat_template,
|
||||||
tokenize=tokenize,
|
tokenize=tokenize,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
return prompt
|
|
||||||
|
|
||||||
|
def apply_mistral_chat_template(
|
||||||
|
tokenizer: MistralTokenizer,
|
||||||
|
messages: List[ChatCompletionMessageParam],
|
||||||
|
chat_template: Optional[str],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> List[int]:
|
||||||
|
return tokenizer.apply_chat_template(
|
||||||
|
messages=messages,
|
||||||
|
chat_template=chat_template,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|||||||
@ -6,7 +6,8 @@ from tqdm import tqdm
|
|||||||
from vllm.engine.arg_utils import EngineArgs
|
from vllm.engine.arg_utils import EngineArgs
|
||||||
from vllm.engine.llm_engine import LLMEngine
|
from vllm.engine.llm_engine import LLMEngine
|
||||||
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
|
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
|
||||||
apply_chat_template,
|
apply_hf_chat_template,
|
||||||
|
apply_mistral_chat_template,
|
||||||
parse_chat_messages)
|
parse_chat_messages)
|
||||||
from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
|
from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
|
||||||
from vllm.inputs.parse import parse_and_batch_prompt
|
from vllm.inputs.parse import parse_and_batch_prompt
|
||||||
@ -19,7 +20,7 @@ from vllm.outputs import EmbeddingRequestOutput, RequestOutput
|
|||||||
from vllm.pooling_params import PoolingParams
|
from vllm.pooling_params import PoolingParams
|
||||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
|
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
|
||||||
get_cached_tokenizer)
|
get_cached_tokenizer)
|
||||||
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
||||||
from vllm.usage.usage_lib import UsageContext
|
from vllm.usage.usage_lib import UsageContext
|
||||||
@ -393,12 +394,21 @@ class LLM:
|
|||||||
conversation, mm_data = parse_chat_messages(messages, model_config,
|
conversation, mm_data = parse_chat_messages(messages, model_config,
|
||||||
tokenizer)
|
tokenizer)
|
||||||
|
|
||||||
prompt = apply_chat_template(
|
prompt: Union[str, List[int]]
|
||||||
tokenizer,
|
if isinstance(tokenizer, MistralTokenizer):
|
||||||
conversation,
|
prompt = apply_mistral_chat_template(
|
||||||
chat_template=chat_template,
|
tokenizer,
|
||||||
add_generation_prompt=add_generation_prompt,
|
messages=messages,
|
||||||
)
|
chat_template=chat_template,
|
||||||
|
add_generation_prompt=add_generation_prompt,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
prompt = apply_hf_chat_template(
|
||||||
|
tokenizer,
|
||||||
|
conversation=conversation,
|
||||||
|
chat_template=chat_template,
|
||||||
|
add_generation_prompt=add_generation_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
inputs: PromptInputs
|
inputs: PromptInputs
|
||||||
if is_list_of(prompt, int):
|
if is_list_of(prompt, int):
|
||||||
|
|||||||
@ -11,7 +11,8 @@ from fastapi import Request
|
|||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
from vllm.engine.protocol import AsyncEngineClient
|
from vllm.engine.protocol import AsyncEngineClient
|
||||||
from vllm.entrypoints.chat_utils import (ConversationMessage,
|
from vllm.entrypoints.chat_utils import (ConversationMessage,
|
||||||
apply_chat_template,
|
apply_hf_chat_template,
|
||||||
|
apply_mistral_chat_template,
|
||||||
load_chat_template,
|
load_chat_template,
|
||||||
parse_chat_messages_futures)
|
parse_chat_messages_futures)
|
||||||
from vllm.entrypoints.logger import RequestLogger
|
from vllm.entrypoints.logger import RequestLogger
|
||||||
@ -35,7 +36,7 @@ from vllm.outputs import CompletionOutput, RequestOutput
|
|||||||
from vllm.sequence import Logprob
|
from vllm.sequence import Logprob
|
||||||
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
|
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
|
||||||
log_tracing_disabled_warning)
|
log_tracing_disabled_warning)
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||||
from vllm.utils import iterate_with_cancellation, random_uuid
|
from vllm.utils import iterate_with_cancellation, random_uuid
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -121,15 +122,27 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
tool.model_dump() for tool in request.tools
|
tool.model_dump() for tool in request.tools
|
||||||
]
|
]
|
||||||
|
|
||||||
prompt = apply_chat_template(
|
prompt: Union[str, List[int]]
|
||||||
tokenizer,
|
if isinstance(tokenizer, MistralTokenizer):
|
||||||
conversation=conversation,
|
prompt = apply_mistral_chat_template(
|
||||||
chat_template=request.chat_template or self.chat_template,
|
tokenizer,
|
||||||
add_generation_prompt=request.add_generation_prompt,
|
messages=request.messages,
|
||||||
tools=tool_dicts,
|
chat_template=request.chat_template or self.chat_template,
|
||||||
documents=request.documents,
|
add_generation_prompt=request.add_generation_prompt,
|
||||||
**(request.chat_template_kwargs or {}),
|
tools=tool_dicts,
|
||||||
)
|
documents=request.documents,
|
||||||
|
**(request.chat_template_kwargs or {}),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
prompt = apply_hf_chat_template(
|
||||||
|
tokenizer,
|
||||||
|
conversation=conversation,
|
||||||
|
chat_template=request.chat_template or self.chat_template,
|
||||||
|
add_generation_prompt=request.add_generation_prompt,
|
||||||
|
tools=tool_dicts,
|
||||||
|
documents=request.documents,
|
||||||
|
**(request.chat_template_kwargs or {}),
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error in applying chat template from request: %s", e)
|
logger.error("Error in applying chat template from request: %s", e)
|
||||||
return self.create_error_response(str(e))
|
return self.create_error_response(str(e))
|
||||||
@ -307,11 +320,10 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
# Send response to echo the input portion of the
|
# Send response to echo the input portion of the
|
||||||
# last message
|
# last message
|
||||||
if request.echo:
|
if request.echo:
|
||||||
last_msg_content: Optional[str] = ""
|
last_msg_content: str = ""
|
||||||
if conversation and conversation[-1].get(
|
if conversation and "content" in conversation[
|
||||||
"content") and conversation[-1].get(
|
-1] and conversation[-1].get("role") == role:
|
||||||
"role") == role:
|
last_msg_content = conversation[-1]["content"] or ""
|
||||||
last_msg_content = conversation[-1]["content"]
|
|
||||||
|
|
||||||
if last_msg_content:
|
if last_msg_content:
|
||||||
for i in range(num_choices):
|
for i in range(num_choices):
|
||||||
@ -659,8 +671,8 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
|
|
||||||
if request.echo:
|
if request.echo:
|
||||||
last_msg_content = ""
|
last_msg_content = ""
|
||||||
if conversation and conversation[-1].get(
|
if conversation and "content" in conversation[-1] and conversation[
|
||||||
"content") and conversation[-1].get("role") == role:
|
-1].get("role") == role:
|
||||||
last_msg_content = conversation[-1]["content"] or ""
|
last_msg_content = conversation[-1]["content"] or ""
|
||||||
|
|
||||||
for choice in choices:
|
for choice in choices:
|
||||||
|
|||||||
@ -2,7 +2,8 @@ from typing import List, Optional, Union
|
|||||||
|
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
from vllm.engine.protocol import AsyncEngineClient
|
from vllm.engine.protocol import AsyncEngineClient
|
||||||
from vllm.entrypoints.chat_utils import (apply_chat_template,
|
from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
|
||||||
|
apply_mistral_chat_template,
|
||||||
load_chat_template,
|
load_chat_template,
|
||||||
parse_chat_messages_futures)
|
parse_chat_messages_futures)
|
||||||
from vllm.entrypoints.logger import RequestLogger
|
from vllm.entrypoints.logger import RequestLogger
|
||||||
@ -18,6 +19,7 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
|
|||||||
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
|
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
|
||||||
OpenAIServing)
|
OpenAIServing)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.transformers_utils.tokenizer import MistralTokenizer
|
||||||
from vllm.utils import random_uuid
|
from vllm.utils import random_uuid
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -66,6 +68,7 @@ class OpenAIServingTokenization(OpenAIServing):
|
|||||||
|
|
||||||
tokenizer = await self.async_engine_client.get_tokenizer(lora_request)
|
tokenizer = await self.async_engine_client.get_tokenizer(lora_request)
|
||||||
|
|
||||||
|
prompt: Union[str, List[int]]
|
||||||
if isinstance(request, TokenizeChatRequest):
|
if isinstance(request, TokenizeChatRequest):
|
||||||
model_config = self.model_config
|
model_config = self.model_config
|
||||||
|
|
||||||
@ -77,12 +80,20 @@ class OpenAIServingTokenization(OpenAIServing):
|
|||||||
logger.warning(
|
logger.warning(
|
||||||
"Multi-modal inputs are ignored during tokenization")
|
"Multi-modal inputs are ignored during tokenization")
|
||||||
|
|
||||||
prompt = apply_chat_template(
|
if isinstance(tokenizer, MistralTokenizer):
|
||||||
tokenizer,
|
prompt = apply_mistral_chat_template(
|
||||||
conversation=conversation,
|
tokenizer,
|
||||||
chat_template=self.chat_template,
|
messages=request.messages,
|
||||||
add_generation_prompt=request.add_generation_prompt,
|
chat_template=self.chat_template,
|
||||||
)
|
add_generation_prompt=request.add_generation_prompt,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
prompt = apply_hf_chat_template(
|
||||||
|
tokenizer,
|
||||||
|
conversation=conversation,
|
||||||
|
chat_template=self.chat_template,
|
||||||
|
add_generation_prompt=request.add_generation_prompt,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
prompt = request.prompt
|
prompt = request.prompt
|
||||||
|
|
||||||
|
|||||||
@ -16,7 +16,7 @@ from mistral_common.tokens.tokenizers.tekken import (SpecialTokenPolicy,
|
|||||||
Tekkenizer)
|
Tekkenizer)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.entrypoints.chat_utils import ConversationMessage
|
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -122,19 +122,19 @@ class MistralTokenizer:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
def encode(self, prompt: str) -> List[int]:
|
def encode(self, prompt: str) -> List[int]:
|
||||||
# `encode ` should only be used for prompt completion
|
# `encode` should only be used for prompt completion
|
||||||
# it should never be used for chat_completion.
|
# it should never be used for chat_completion.
|
||||||
# For chat completion use `apply_chat_template`
|
# For chat completion use `apply_chat_template`
|
||||||
return self.tokenizer.encode(prompt, bos=True, eos=False)
|
return self.tokenizer.encode(prompt, bos=True, eos=False)
|
||||||
|
|
||||||
def apply_chat_template(self,
|
def apply_chat_template(self,
|
||||||
conversation: List["ConversationMessage"],
|
messages: List["ChatCompletionMessageParam"],
|
||||||
tools: Optional[Dict[str, Any]] = None,
|
tools: Optional[Dict[str, Any]] = None,
|
||||||
**kwargs) -> List[int]:
|
**kwargs) -> List[int]:
|
||||||
assert tools is None, "`tools` are not yet supported."
|
assert tools is None, "`tools` are not yet supported."
|
||||||
|
|
||||||
request = ChatCompletionRequest(
|
request = ChatCompletionRequest(
|
||||||
messages=conversation) # type: ignore[type-var]
|
messages=messages) # type: ignore[type-var]
|
||||||
encoded = self.mistral.encode_chat_completion(request)
|
encoded = self.mistral.encode_chat_completion(request)
|
||||||
|
|
||||||
# encode-decode to get clean prompt
|
# encode-decode to get clean prompt
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user