add support for tokenizer revision (#1163)

Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
Federico Cassano, 2023-10-02 22:19:46 -04:00, committed by GitHub
parent ba0bfd40e2
commit 66d18a7fb0
5 changed files with 23 additions and 1 deletion


@@ -41,6 +41,9 @@ class ModelConfig:
         revision: The specific model version to use. It can be a branch name,
             a tag name, or a commit id. If unspecified, will use the default
             version.
+        tokenizer_revision: The specific tokenizer version to use. It can be a
+            branch name, a tag name, or a commit id. If unspecified, will use
+            the default version.
         max_model_len: Maximum length of a sequence (including prompt and
             output). If None, will be derived from the model.
         quantization: Quantization method that was used to quantize the model
@@ -58,6 +61,7 @@
         dtype: str,
         seed: int,
         revision: Optional[str] = None,
+        tokenizer_revision: Optional[str] = None,
         max_model_len: Optional[int] = None,
         quantization: Optional[str] = None,
     ) -> None:
@@ -69,6 +73,7 @@
         self.load_format = load_format
         self.seed = seed
         self.revision = revision
+        self.tokenizer_revision = tokenizer_revision
         self.quantization = quantization
 
         self.hf_config = get_config(model, trust_remote_code, revision)
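
For orientation, a minimal sketch of the widened constructor as it might be called directly. In practice ModelConfig is built by EngineArgs (next file); every value below is a placeholder for illustration, not part of this commit:

    from vllm.config import ModelConfig

    config = ModelConfig(
        model="facebook/opt-125m",       # placeholder model
        tokenizer="facebook/opt-125m",
        tokenizer_mode="auto",
        trust_remote_code=False,
        download_dir=None,
        load_format="auto",
        dtype="auto",
        seed=0,
        revision=None,                   # model weights: default revision
        tokenizer_revision="main",       # new: pin the tokenizer independently
    )

Note that the tokenizer now has its own revision while get_config() continues to resolve the model config with the existing `revision`.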


@@ -29,6 +29,7 @@ class EngineArgs:
     max_num_seqs: int = 256
     disable_log_stats: bool = False
     revision: Optional[str] = None
+    tokenizer_revision: Optional[str] = None
     quantization: Optional[str] = None
 
     def __post_init__(self):
@@ -57,6 +58,13 @@ class EngineArgs:
             help='the specific model version to use. It can be a branch '
             'name, a tag name, or a commit id. If unspecified, will use '
             'the default version.')
+        parser.add_argument(
+            '--tokenizer-revision',
+            type=str,
+            default=None,
+            help='the specific tokenizer version to use. It can be a branch '
+            'name, a tag name, or a commit id. If unspecified, will use '
+            'the default version.')
         parser.add_argument('--tokenizer-mode',
                             type=str,
                             default=EngineArgs.tokenizer_mode,
@@ -175,7 +183,8 @@ class EngineArgs:
                                    self.tokenizer_mode, self.trust_remote_code,
                                    self.download_dir, self.load_format,
                                    self.dtype, self.seed, self.revision,
-                                   self.max_model_len, self.quantization)
+                                   self.tokenizer_revision, self.max_model_len,
+                                   self.quantization)
         cache_config = CacheConfig(
             self.block_size, self.gpu_memory_utilization, self.swap_space,
             getattr(model_config.hf_config, 'sliding_window', None))
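
A sketch of the threading above, assuming the last hunk sits in vLLM's EngineArgs.create_engine_configs() (the method is not named in this diff); values are placeholders:

    from vllm.engine.arg_utils import EngineArgs

    # Normally populated from CLI flags, including the new --tokenizer-revision.
    args = EngineArgs(model="facebook/opt-125m", tokenizer_revision="main")
    model_config = args.create_engine_configs()[0]  # ModelConfig comes first
    assert model_config.tokenizer_revision == "main"

On the command line the same knob is exposed as --tokenizer-revision, mirroring the existing --revision flag.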


@@ -75,6 +75,7 @@ class LLMEngine:
             f"tokenizer={model_config.tokenizer!r}, "
             f"tokenizer_mode={model_config.tokenizer_mode}, "
             f"revision={model_config.revision}, "
+            f"tokenizer_revision={model_config.tokenizer_revision}, "
             f"trust_remote_code={model_config.trust_remote_code}, "
             f"dtype={model_config.dtype}, "
             f"max_seq_len={model_config.max_model_len}, "
@@ -98,6 +99,7 @@ class LLMEngine:
             model_config.tokenizer,
             tokenizer_mode=model_config.tokenizer_mode,
             trust_remote_code=model_config.trust_remote_code,
+            tokenizer_revision=model_config.tokenizer_revision,
             revision=model_config.revision)
         self.seq_counter = Counter()
 


@@ -42,6 +42,8 @@ class LLM:
             quantized and use `dtype` to determine the data type of the weights.
         revision: The specific model version to use. It can be a branch name,
             a tag name, or a commit id.
+        tokenizer_revision: The specific tokenizer version to use. It can be a
+            branch name, a tag name, or a commit id.
         seed: The seed to initialize the random number generator for sampling.
         gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
             reserve for the model weights, activations, and KV cache. Higher
@@ -65,6 +67,7 @@ class LLM:
         dtype: str = "auto",
         quantization: Optional[str] = None,
         revision: Optional[str] = None,
+        tokenizer_revision: Optional[str] = None,
         seed: int = 0,
         gpu_memory_utilization: float = 0.9,
         swap_space: int = 4,
@@ -81,6 +84,7 @@ class LLM:
             dtype=dtype,
             quantization=quantization,
             revision=revision,
+            tokenizer_revision=tokenizer_revision,
             seed=seed,
             gpu_memory_utilization=gpu_memory_utilization,
             swap_space=swap_space,
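
End to end, the public entry point now accepts the new keyword. A minimal usage sketch (model name and revision values are placeholders):

    from vllm import LLM

    llm = LLM(
        model="facebook/opt-125m",   # placeholder model
        revision=None,               # weights: default revision
        tokenizer_revision="main",   # tokenizer pinned to a branch/tag/commit
    )
    outputs = llm.generate("Hello, my name is")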


@@ -16,6 +16,7 @@ def get_tokenizer(
     *args,
     tokenizer_mode: str = "auto",
     trust_remote_code: bool = False,
+    tokenizer_revision: Optional[str] = None,
     **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
     """Gets a tokenizer for the given model name via Huggingface."""
@@ -37,6 +38,7 @@ def get_tokenizer(
             tokenizer_name,
             *args,
             trust_remote_code=trust_remote_code,
+            tokenizer_revision=tokenizer_revision,
             **kwargs)
     except TypeError as e:
         # The LLaMA tokenizer causes a protobuf error in some environments.
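
At the lowest level, the helper forwards the new kwarg to the underlying Hugging Face from_pretrained call (the callee is not visible in this hunk). A sketch of calling it directly, with placeholder model name and revision:

    from vllm.transformers_utils.tokenizer import get_tokenizer

    tokenizer = get_tokenizer(
        "facebook/opt-125m",        # placeholder model name
        tokenizer_mode="auto",
        tokenizer_revision="main",  # forwarded by the patched helper above
    )
    print(type(tokenizer).__name__)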