[Tokenizer] Add tokenizer mode (#298)

Woosuk Kwon 2023-06-28 14:19:22 -07:00 committed by GitHub
parent 425040d4c1
commit 998d9d1509
6 changed files with 38 additions and 5 deletions

vllm/config.py

@@ -17,6 +17,8 @@ class ModelConfig:
     Args:
         model: Name or path of the huggingface model to use.
         tokenizer: Name or path of the huggingface tokenizer to use.
+        tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
+            available, and "slow" will always use the slow tokenizer.
         download_dir: Directory to download and load the weights, default to the
             default cache directory of huggingface.
         use_np_weights: Save a numpy copy of model weights for faster loading.
@@ -31,7 +33,8 @@ class ModelConfig:
     def __init__(
         self,
         model: str,
-        tokenizer: Optional[str],
+        tokenizer: str,
+        tokenizer_mode: str,
         download_dir: Optional[str],
         use_np_weights: bool,
         use_dummy_weights: bool,
@@ -40,6 +43,7 @@ class ModelConfig:
     ) -> None:
         self.model = model
         self.tokenizer = tokenizer
+        self.tokenizer_mode = tokenizer_mode
         self.download_dir = download_dir
         self.use_np_weights = use_np_weights
         self.use_dummy_weights = use_dummy_weights
@@ -47,6 +51,15 @@ class ModelConfig:
         self.hf_config: PretrainedConfig = AutoConfig.from_pretrained(model)
         self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
+        self._verify_tokenizer_mode()
+
+    def _verify_tokenizer_mode(self) -> None:
+        tokenizer_mode = self.tokenizer_mode.lower()
+        if tokenizer_mode not in ["auto", "slow"]:
+            raise ValueError(
+                f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
+                "either 'auto' or 'slow'.")
+        self.tokenizer_mode = tokenizer_mode
 
     def verify_with_parallel_config(
         self,
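
Note: the new `_verify_tokenizer_mode` normalizes the mode to lowercase and rejects anything other than "auto" or "slow". A minimal standalone sketch of that behavior (the `verify_tokenizer_mode` helper below is hypothetical, for illustration only):

```python
# Hypothetical standalone helper mirroring ModelConfig._verify_tokenizer_mode.
def verify_tokenizer_mode(tokenizer_mode: str) -> str:
    mode = tokenizer_mode.lower()
    if mode not in ["auto", "slow"]:
        raise ValueError(f"Unknown tokenizer mode: {tokenizer_mode}. Must be "
                         "either 'auto' or 'slow'.")
    return mode

print(verify_tokenizer_mode("AUTO"))  # "auto": the check is case-insensitive
try:
    verify_tokenizer_mode("fast")
except ValueError as err:
    print(err)  # Unknown tokenizer mode: fast. Must be either 'auto' or 'slow'.
```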

vllm/engine/arg_utils.py

@@ -12,6 +12,7 @@ class EngineArgs:
     """Arguments for vLLM engine."""
     model: str
     tokenizer: Optional[str] = None
+    tokenizer_mode: str = "auto"
     download_dir: Optional[str] = None
     use_np_weights: bool = False
     use_dummy_weights: bool = False
@@ -42,6 +43,12 @@
                             help='name or path of the huggingface model to use')
         parser.add_argument('--tokenizer', type=str, default=EngineArgs.tokenizer,
                             help='name or path of the huggingface tokenizer to use')
+        parser.add_argument('--tokenizer-mode', type=str,
+                            default=EngineArgs.tokenizer_mode,
+                            choices=['auto', 'slow'],
+                            help='tokenizer mode. "auto" will use the fast '
+                                 'tokenizer if available, and "slow" will '
+                                 'always use the slow tokenizer.')
         parser.add_argument('--download-dir', type=str,
                             default=EngineArgs.download_dir,
                             help='directory to download and load the weights, '
@@ -109,8 +116,8 @@
     ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
         # Initialize the configs.
         model_config = ModelConfig(
-            self.model, self.tokenizer, self.download_dir, self.use_np_weights,
-            self.use_dummy_weights, self.dtype, self.seed)
+            self.model, self.tokenizer, self.tokenizer_mode, self.download_dir,
+            self.use_np_weights, self.use_dummy_weights, self.dtype, self.seed)
         cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization,
                                    self.swap_space)
         parallel_config = ParallelConfig(self.pipeline_parallel_size,
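
With the flag registered above, `--tokenizer-mode` flows from the CLI into `ModelConfig`. A sketch of exercising it programmatically, assuming `from_cli_args` exists alongside `add_cli_args` (it is not shown in this diff; the model name is only an example):

```python
# Sketch: parse --tokenizer-mode through the CLI helpers shown above.
import argparse

from vllm.engine.arg_utils import EngineArgs

parser = argparse.ArgumentParser()
parser = EngineArgs.add_cli_args(parser)  # add_cli_args returns the parser
args = parser.parse_args(
    ["--model", "facebook/opt-125m", "--tokenizer-mode", "slow"])
engine_args = EngineArgs.from_cli_args(args)  # assumed companion helper
print(engine_args.tokenizer_mode)  # "slow"
```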

vllm/engine/llm_engine.py

@@ -61,6 +61,7 @@ class LLMEngine:
             "Initializing an LLM engine with config: "
             f"model={model_config.model!r}, "
             f"tokenizer={model_config.tokenizer!r}, "
+            f"tokenizer_mode={model_config.tokenizer_mode}, "
             f"dtype={model_config.dtype}, "
             f"use_dummy_weights={model_config.use_dummy_weights}, "
             f"download_dir={model_config.download_dir!r}, "
@@ -77,7 +78,8 @@
         self.log_stats = log_stats
         self._verify_args()
 
-        self.tokenizer = get_tokenizer(model_config.tokenizer)
+        self.tokenizer = get_tokenizer(model_config.tokenizer,
+                                       model_config.tokenizer_mode)
         self.seq_counter = Counter()
 
         # Create the parallel GPU workers.
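
Because the engine now forwards the mode, a slow-mode engine ends up with a pure-Python tokenizer. A sketch of the observable effect, assuming the import path below and an example model name:

```python
# Sketch: slow mode should always produce a non-fast tokenizer.
from vllm.transformers_utils.tokenizer import get_tokenizer  # assumed path

tok = get_tokenizer("facebook/opt-125m", tokenizer_mode="slow")
print(tok.is_fast)  # False
```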

vllm/entrypoints/llm.py

@@ -26,6 +26,8 @@ class LLM:
     Args:
         model: The name or path of a HuggingFace Transformers model.
         tokenizer: The name or path of a HuggingFace Transformers tokenizer.
+        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
+            if available, and "slow" will always use the slow tokenizer.
         tensor_parallel_size: The number of GPUs to use for distributed
             execution with tensor parallelism.
         dtype: The data type for the model weights and activations. Currently,
@@ -40,6 +42,7 @@
         self,
         model: str,
         tokenizer: Optional[str] = None,
+        tokenizer_mode: str = "auto",
         tensor_parallel_size: int = 1,
         dtype: str = "auto",
         seed: int = 0,
@@ -50,6 +53,7 @@
         engine_args = EngineArgs(
             model=model,
             tokenizer=tokenizer,
+            tokenizer_mode=tokenizer_mode,
             tensor_parallel_size=tensor_parallel_size,
             dtype=dtype,
             seed=seed,
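
At the top of the stack, the new argument simply threads through `LLM` into `EngineArgs`; for example (model name illustrative):

```python
# Sketch: forcing the slow tokenizer via the high-level LLM API.
from vllm import LLM

llm = LLM(model="facebook/opt-125m", tokenizer_mode="slow")
outputs = llm.generate("Hello, my name is")
```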

vllm/entrypoints/openai/api_server.py

@@ -313,7 +313,7 @@ if __name__ == "__main__":
     engine = AsyncLLMEngine.from_engine_args(engine_args)
 
     # A separate tokenizer to map token IDs to strings.
-    tokenizer = get_tokenizer(args.model)
+    tokenizer = get_tokenizer(engine_args.tokenizer, engine_args.tokenizer_mode)
 
     uvicorn.run(app, host=args.host, port=args.port, log_level="info",
                 timeout_keep_alive=TIMEOUT_KEEP_ALIVE)

vllm/transformers_utils/tokenizer.py

@@ -13,10 +13,17 @@ _FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"
 def get_tokenizer(
     tokenizer_name: str,
+    tokenizer_mode: str = "auto",
     *args,
     **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
     """Gets a tokenizer for the given model name via Huggingface."""
+    if tokenizer_mode == "slow":
+        if kwargs.get("use_fast", False):
+            raise ValueError(
+                "Cannot use the fast tokenizer in slow tokenizer mode.")
+        kwargs["use_fast"] = False
+
     if "llama" in tokenizer_name.lower() and kwargs.get("use_fast", True):
         logger.info(
             "For some LLaMA-based models, initializing the fast tokenizer may "