diff --git a/cacheflow/frontend/fastapi_frontend.py b/cacheflow/frontend/fastapi_frontend.py
index e5c116d7b76d0..59e66a4ce5fe7 100644
--- a/cacheflow/frontend/fastapi_frontend.py
+++ b/cacheflow/frontend/fastapi_frontend.py
@@ -7,12 +7,12 @@ from typing import List, Dict, Optional
 from fastapi import FastAPI, Request
 from fastapi.responses import StreamingResponse
 import ray
-from transformers import AutoTokenizer
 import uvicorn
 
 from cacheflow.core.server import (Server, add_server_arguments,
                                    process_server_arguments,
                                    initialize_cluster)
+from cacheflow.frontend.utils import get_tokenizer
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import Sequence, SequenceGroup
 from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory
@@ -44,7 +44,7 @@ class FastAPIServer:
     ):
         self.block_size = block_size
 
-        self.tokenizer = AutoTokenizer.from_pretrained(model)
+        self.tokenizer = get_tokenizer(model)
         self.seq_group_counter = Counter()
         self.seq_counter = Counter()
         if server_use_ray:
diff --git a/cacheflow/frontend/simple_frontend.py b/cacheflow/frontend/simple_frontend.py
index eca81c9a167eb..da3639530cd6f 100644
--- a/cacheflow/frontend/simple_frontend.py
+++ b/cacheflow/frontend/simple_frontend.py
@@ -1,8 +1,7 @@
 import time
 from typing import List, Optional, Tuple
 
-from transformers import AutoTokenizer
-
+from cacheflow.frontend.utils import get_tokenizer
 from cacheflow.logger import init_logger
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import Sequence, SequenceGroup
@@ -21,7 +20,7 @@ class SimpleFrontend:
     ) -> None:
         self.block_size = block_size
 
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.tokenizer = get_tokenizer(model_name)
         self.seq_group_counter = Counter()
         self.seq_counter = Counter()
         self.inputs: List[Tuple[SequenceGroup, SamplingParams]] = []
diff --git a/cacheflow/frontend/utils.py b/cacheflow/frontend/utils.py
new file mode 100644
index 0000000000000..efb50d9aced63
--- /dev/null
+++ b/cacheflow/frontend/utils.py
@@ -0,0 +1,22 @@
+from typing import Union
+
+from transformers import (AutoConfig, AutoTokenizer, PreTrainedTokenizer,
+                          PreTrainedTokenizerFast)
+
+
+_MODEL_TYPES_WITH_SLOW_TOKENIZER = [
+    # LLaMA fast tokenizer has a bug related to protobuf.
+    # See https://github.com/WoosukKwon/cacheflow/issues/80#issue-1698550554
+    "llama",
+]
+
+
+def get_tokenizer(
+    model_name: str,
+    *args,
+    **kwargs,
+) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+    config = AutoConfig.from_pretrained(model_name)
+    if config.model_type in _MODEL_TYPES_WITH_SLOW_TOKENIZER:
+        kwargs["use_fast"] = False
+    return AutoTokenizer.from_pretrained(model_name, *args, **kwargs)
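
For context, a minimal sketch of how the new get_tokenizer helper would be called from frontend code. The model names below are illustrative examples and are not taken from this diff:

    from cacheflow.frontend.utils import get_tokenizer

    # Non-LLaMA models get the default (fast) tokenizer from AutoTokenizer.
    opt_tokenizer = get_tokenizer("facebook/opt-125m")

    # For LLaMA-family models the helper forces use_fast=False, working around
    # the protobuf-related fast-tokenizer bug referenced in issue #80.
    llama_tokenizer = get_tokenizer("huggyllama/llama-7b")

    # Any extra arguments are forwarded to AutoTokenizer.from_pretrained unchanged.
    padded_tokenizer = get_tokenizer("facebook/opt-125m", padding_side="left")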