Enable LLaMA fast tokenizer (#132)

Woosuk Kwon 2023-05-28 02:51:42 -07:00 committed by GitHub
parent 56b7f0efa4
commit 337871c6fd
2 changed files with 11 additions and 7 deletions


@@ -129,7 +129,7 @@ class SamplingParams:
                 f"frequency_penalty={self.frequency_penalty}, "
                 f"temperature={self.temperature}, "
                 f"top_p={self.top_p}, "
-                f"top_k={self.top_k},"
+                f"top_k={self.top_k}, "
                 f"use_beam_search={self.use_beam_search}, "
                 f"stop={self.stop}, "
                 f"ignore_eos={self.ignore_eos}, "


@@ -7,11 +7,7 @@ from cacheflow.logger import init_logger
 logger = init_logger(__name__)
-_MODEL_TYPES_WITH_SLOW_TOKENIZER = [
-    # LLaMA fast tokenizer has a bug related to protobuf.
-    # See https://github.com/WoosukKwon/cacheflow/issues/80#issue-1698550554
-    "llama",
-]
+_MODEL_TYPES_WITH_SLOW_TOKENIZER = []
 def get_tokenizer(
@@ -20,7 +16,15 @@ def get_tokenizer(
     **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
     config = AutoConfig.from_pretrained(model_name)
-    if config.model_type in _MODEL_TYPES_WITH_SLOW_TOKENIZER:
+    if config.model_type == "llama" and getattr(kwargs, "use_fast", True):
+        # LLaMA fast tokenizer causes protobuf errors in some environments.
+        # However, we found that the below LLaMA fast tokenizer works well in
+        # most environments.
+        model_name = "hf-internal-testing/llama-tokenizer"
+        logger.info(
+            f"Using the LLaMA fast tokenizer in '{model_name}' to avoid "
+            "potential protobuf errors.")
+    elif config.model_type in _MODEL_TYPES_WITH_SLOW_TOKENIZER:
         if getattr(kwargs, "use_fast", False) == True:
             raise ValueError(
                 f"Cannot use the fast tokenizer for {config.model_type} due to "