mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-19 23:57:19 +08:00
[Bugfix] Access get_vocab instead of vocab in tool parsers (#9188)
This commit is contained in:
parent
21906a6f50
commit
cfaa6008e6
@ -1,6 +1,7 @@
|
||||
import importlib
|
||||
import importlib.util
|
||||
import os
|
||||
from functools import cached_property
|
||||
from typing import Callable, Dict, List, Optional, Sequence, Type, Union
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
@ -29,6 +30,12 @@ class ToolParser:
|
||||
|
||||
self.model_tokenizer = tokenizer
|
||||
|
||||
@cached_property
|
||||
def vocab(self) -> Dict[str, int]:
|
||||
# NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab
|
||||
# whereas all tokenizers have .get_vocab()
|
||||
return self.model_tokenizer.get_vocab()
|
||||
|
||||
def adjust_request(
|
||||
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
|
||||
"""
|
||||
|
||||
@ -50,10 +50,9 @@ class Hermes2ProToolParser(ToolParser):
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ToolParser "
|
||||
"constructor during construction.")
|
||||
self.tool_call_start_token_id: int = self.model_tokenizer.vocab.get(
|
||||
self.tool_call_start_token, None)
|
||||
self.tool_call_end_token_id: int = self.model_tokenizer.vocab.get(
|
||||
self.tool_call_end_token, None)
|
||||
self.tool_call_start_token_id = self.vocab.get(
|
||||
self.tool_call_start_token)
|
||||
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
|
||||
if not self.tool_call_start_token_id or not self.tool_call_end_token_id:
|
||||
raise RuntimeError(
|
||||
"Hermes 2 Pro Tool parser could not locate tool call start/end "
|
||||
|
||||
@ -61,8 +61,7 @@ class MistralToolParser(ToolParser):
|
||||
self.streamed_args_for_tool: List[str] = [
|
||||
] # map what has been streamed for each tool so far to a list
|
||||
self.bot_token = "[TOOL_CALLS]"
|
||||
self.bot_token_id = self.model_tokenizer.get_vocab().get(
|
||||
self.bot_token, None)
|
||||
self.bot_token_id = self.vocab.get(self.bot_token)
|
||||
self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
|
||||
if not self.bot_token_id:
|
||||
raise RuntimeError(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user