mirror of https://git.datalinker.icu/vllm-project/vllm.git
Enable LLaMA fast tokenizer (#132)
commit 337871c6fd
parent 56b7f0efa4
@@ -129,7 +129,7 @@ class SamplingParams:
                 f"frequency_penalty={self.frequency_penalty}, "
                 f"temperature={self.temperature}, "
                 f"top_p={self.top_p}, "
-                f"top_k={self.top_k},"
+                f"top_k={self.top_k}, "
                 f"use_beam_search={self.use_beam_search}, "
                 f"stop={self.stop}, "
                 f"ignore_eos={self.ignore_eos}, "
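The hunk above only fixes the repr formatting of SamplingParams: without the trailing space inside the f-string, the top_k and use_beam_search fields run together in the printed output. A minimal standalone sketch of the difference (plain variables here, not the real class):

# Standalone illustration of the one-character change above;
# SamplingParams itself is not redefined here.
top_k, use_beam_search = 40, False

before = f"top_k={top_k}," f"use_beam_search={use_beam_search}"
after = f"top_k={top_k}, " f"use_beam_search={use_beam_search}"

print(before)  # top_k=40,use_beam_search=False  (fields run together)
print(after)   # top_k=40, use_beam_search=False (consistent "key=value, " spacing)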
@@ -7,11 +7,7 @@ from cacheflow.logger import init_logger
 
 logger = init_logger(__name__)
 
-_MODEL_TYPES_WITH_SLOW_TOKENIZER = [
-    # LLaMA fast tokenizer has a bug related to protobuf.
-    # See https://github.com/WoosukKwon/cacheflow/issues/80#issue-1698550554
-    "llama",
-]
+_MODEL_TYPES_WITH_SLOW_TOKENIZER = []
 
 
 def get_tokenizer(
@@ -20,7 +16,15 @@ def get_tokenizer(
     **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
     config = AutoConfig.from_pretrained(model_name)
-    if config.model_type in _MODEL_TYPES_WITH_SLOW_TOKENIZER:
+    if config.model_type == "llama" and getattr(kwargs, "use_fast", True):
+        # LLaMA fast tokenizer causes protobuf errors in some environments.
+        # However, we found that the below LLaMA fast tokenizer works well in
+        # most environments.
+        model_name = "hf-internal-testing/llama-tokenizer"
+        logger.info(
+            f"Using the LLaMA fast tokenizer in '{model_name}' to avoid "
+            "potential protobuf errors.")
+    elif config.model_type in _MODEL_TYPES_WITH_SLOW_TOKENIZER:
         if getattr(kwargs, "use_fast", False) == True:
             raise ValueError(
                 f"Cannot use the fast tokenizer for {config.model_type} due to "
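The new branch above redirects any request for a LLaMA tokenizer to the fast tokenizer published at hf-internal-testing/llama-tokenizer instead of falling back to the slow one. A rough standalone sketch of that substitution using only transformers, since the diff does not show which module defines get_tokenizer; the model name below is just an illustrative placeholder:

# Standalone sketch of the substitution performed by the new branch above.
# "huggyllama/llama-7b" is only an illustrative model name, not from the diff.
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerFast

model_name = "huggyllama/llama-7b"
config = AutoConfig.from_pretrained(model_name)

if config.model_type == "llama":
    # Same idea as the diff: load the known-good LLaMA fast tokenizer
    # instead of the tokenizer files bundled with the model repo.
    model_name = "hf-internal-testing/llama-tokenizer"

tokenizer = AutoTokenizer.from_pretrained(model_name)
assert isinstance(tokenizer, PreTrainedTokenizerFast)
print(tokenizer("Hello, world!").input_ids)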