[Bugfix] Access get_vocab instead of vocab in tool parsers (#9188)

This commit is contained in:
Cyrus Leung 2024-10-09 22:59:57 +08:00 committed by GitHub
parent 21906a6f50
commit cfaa6008e6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 11 additions and 6 deletions

View File

@ -1,6 +1,7 @@
import importlib
import importlib.util
import os
from functools import cached_property
from typing import Callable, Dict, List, Optional, Sequence, Type, Union
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
@ -29,6 +30,12 @@ class ToolParser:
self.model_tokenizer = tokenizer
@cached_property
def vocab(self) -> Dict[str, int]:
# Returns the mapping produced by self.model_tokenizer.get_vocab()
# (annotated here as token string -> token id).
# Declared as a cached_property so the result is stored on the instance
# after the first access and later accesses reuse it instead of calling
# get_vocab() again.
# NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab,
# whereas all tokenizers have .get_vocab()
return self.model_tokenizer.get_vocab()
def adjust_request(
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
"""

View File

@ -50,10 +50,9 @@ class Hermes2ProToolParser(ToolParser):
raise ValueError(
"The model tokenizer must be passed to the ToolParser "
"constructor during construction.")
self.tool_call_start_token_id: int = self.model_tokenizer.vocab.get(
self.tool_call_start_token, None)
self.tool_call_end_token_id: int = self.model_tokenizer.vocab.get(
self.tool_call_end_token, None)
self.tool_call_start_token_id = self.vocab.get(
self.tool_call_start_token)
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
if not self.tool_call_start_token_id or not self.tool_call_end_token_id:
raise RuntimeError(
"Hermes 2 Pro Tool parser could not locate tool call start/end "

View File

@ -61,8 +61,7 @@ class MistralToolParser(ToolParser):
self.streamed_args_for_tool: List[str] = [
] # map what has been streamed for each tool so far to a list
self.bot_token = "[TOOL_CALLS]"
self.bot_token_id = self.model_tokenizer.get_vocab().get(
self.bot_token, None)
self.bot_token_id = self.vocab.get(self.bot_token)
self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
if not self.bot_token_id:
raise RuntimeError(