Add trust-remote-code flag to handle remote tokenizers (#364)

codethazine 2023-07-07 20:04:58 +02:00 committed by GitHub
parent be54f8e5c4
commit a945fcc2ae
5 changed files with 39 additions and 6 deletions
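
This change threads a single boolean from the public entry points down to the tokenizer loader: the LLM class and the new --trust-remote-code CLI flag feed EngineArgs, which passes it into ModelConfig, which LLMEngine forwards to get_tokenizer and ultimately to AutoTokenizer.from_pretrained. A minimal usage sketch of the new Python-API parameter (the model name is illustrative; any repository that ships custom tokenizer code applies):

from vllm import LLM

# Without trust_remote_code=True, such a repository now fails with a
# RuntimeError that suggests the flag (see the tokenizer changes below).
llm = LLM(model="mosaicml/mpt-7b", trust_remote_code=True)

The CLI entry points get the equivalent opt-in via the new argparse flag, e.g. python -m vllm.entrypoints.api_server --model mosaicml/mpt-7b --trust-remote-code.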

vllm/config.py

@@ -20,6 +20,8 @@ class ModelConfig:
         tokenizer: Name or path of the huggingface tokenizer to use.
         tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
             available, and "slow" will always use the slow tokenizer.
+        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
+            downloading the model and tokenizer.
         download_dir: Directory to download and load the weights, default to the
             default cache directory of huggingface.
         use_np_weights: Save a numpy copy of model weights for faster loading.
@@ -36,6 +38,7 @@ class ModelConfig:
         model: str,
         tokenizer: str,
         tokenizer_mode: str,
+        trust_remote_code: bool,
         download_dir: Optional[str],
         use_np_weights: bool,
         use_dummy_weights: bool,
@@ -45,6 +48,7 @@ class ModelConfig:
         self.model = model
         self.tokenizer = tokenizer
         self.tokenizer_mode = tokenizer_mode
+        self.trust_remote_code = trust_remote_code
         self.download_dir = download_dir
         self.use_np_weights = use_np_weights
         self.use_dummy_weights = use_dummy_weights
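
Note that trust_remote_code is inserted positionally between tokenizer_mode and download_dir, so positional call sites must be reordered; the third arg_utils.py hunk below does exactly that. A sketch of the new constructor shape (values illustrative; the trailing dtype and seed arguments follow from that hunk):

config = ModelConfig("facebook/opt-125m",  # model
                     "facebook/opt-125m",  # tokenizer
                     "auto",               # tokenizer_mode
                     False,                # trust_remote_code (safe default)
                     None,                 # download_dir
                     False,                # use_np_weights
                     False,                # use_dummy_weights
                     "auto",               # dtype
                     0)                    # seed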

vllm/engine/arg_utils.py

@@ -13,6 +13,7 @@ class EngineArgs:
     model: str
     tokenizer: Optional[str] = None
     tokenizer_mode: str = 'auto'
+    trust_remote_code: bool = False
     download_dir: Optional[str] = None
     use_np_weights: bool = False
     use_dummy_weights: bool = False
@@ -55,6 +56,9 @@ class EngineArgs:
                             help='tokenizer mode. "auto" will use the fast '
                                  'tokenizer if available, and "slow" will '
                                  'always use the slow tokenizer.')
+        parser.add_argument('--trust-remote-code',
+                            action='store_true',
+                            help='trust remote code from huggingface')
         parser.add_argument('--download-dir',
                             type=str,
                             default=EngineArgs.download_dir,
@@ -141,9 +145,10 @@ class EngineArgs:
     ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
         # Initialize the configs.
-        model_config = ModelConfig(self.model, self.tokenizer,
-                                   self.tokenizer_mode, self.download_dir,
-                                   self.use_np_weights, self.use_dummy_weights,
-                                   self.dtype, self.seed)
+        model_config = ModelConfig(self.model, self.tokenizer,
+                                   self.tokenizer_mode, self.trust_remote_code,
+                                   self.download_dir, self.use_np_weights,
+                                   self.use_dummy_weights, self.dtype,
+                                   self.seed)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
                                    self.swap_space)
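
Since the flag uses action='store_true', passing bare --trust-remote-code flips the default False to True. A quick round-trip sketch (assuming EngineArgs is importable from the package root, as elsewhere in the codebase):

import argparse

from vllm import EngineArgs

parser = argparse.ArgumentParser()
parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args(["--model", "facebook/opt-125m", "--trust-remote-code"])
engine_args = EngineArgs.from_cli_args(args)
assert engine_args.trust_remote_code is True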

vllm/engine/llm_engine.py

@@ -62,6 +62,7 @@ class LLMEngine:
             f"model={model_config.model!r}, "
             f"tokenizer={model_config.tokenizer!r}, "
             f"tokenizer_mode={model_config.tokenizer_mode}, "
+            f"trust_remote_code={model_config.trust_remote_code}, "
             f"dtype={model_config.dtype}, "
             f"use_dummy_weights={model_config.use_dummy_weights}, "
             f"download_dir={model_config.download_dir!r}, "
@@ -78,7 +79,9 @@ class LLMEngine:
         self._verify_args()
-        self.tokenizer = get_tokenizer(
-            model_config.tokenizer, tokenizer_mode=model_config.tokenizer_mode)
+        self.tokenizer = get_tokenizer(
+            model_config.tokenizer,
+            tokenizer_mode=model_config.tokenizer_mode,
+            trust_remote_code=model_config.trust_remote_code)
         self.seq_counter = Counter()
         # Create the parallel GPU workers.
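
Because get_tokenizer declares tokenizer_mode and trust_remote_code after *args, both are keyword-only; the engine therefore forwards them by name, as sketched here with the names from the hunk above:

# Keyword-only parameters: passing these positionally would be swallowed
# by *args and forwarded to AutoTokenizer instead.
tokenizer = get_tokenizer(
    model_config.tokenizer,
    tokenizer_mode=model_config.tokenizer_mode,
    trust_remote_code=model_config.trust_remote_code)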

vllm/entrypoints/llm.py

@@ -28,6 +28,8 @@ class LLM:
         tokenizer: The name or path of a HuggingFace Transformers tokenizer.
         tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
             if available, and "slow" will always use the slow tokenizer.
+        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
+            downloading the model and tokenizer.
         tensor_parallel_size: The number of GPUs to use for distributed
             execution with tensor parallelism.
         dtype: The data type for the model weights and activations. Currently,
@@ -43,6 +45,7 @@ class LLM:
         model: str,
         tokenizer: Optional[str] = None,
         tokenizer_mode: str = "auto",
+        trust_remote_code: bool = False,
         tensor_parallel_size: int = 1,
         dtype: str = "auto",
         seed: int = 0,
@@ -54,6 +57,7 @@ class LLM:
             model=model,
             tokenizer=tokenizer,
             tokenizer_mode=tokenizer_mode,
+            trust_remote_code=trust_remote_code,
             tensor_parallel_size=tensor_parallel_size,
             dtype=dtype,
             seed=seed,
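
With the constructor parameter in place, offline inference against a remote-code model works end to end. A sketch (model name illustrative, sampling values arbitrary):

from vllm import LLM, SamplingParams

llm = LLM(model="mosaicml/mpt-7b", trust_remote_code=True)
sampling_params = SamplingParams(temperature=0.8, max_tokens=64)
outputs = llm.generate(["The future of AI is"], sampling_params)
print(outputs[0].outputs[0].text)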

vllm/transformers_utils/tokenizer.py

@@ -15,6 +15,7 @@ def get_tokenizer(
     tokenizer_name: str,
     *args,
     tokenizer_mode: str = "auto",
+    trust_remote_code: bool = False,
     **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
     """Gets a tokenizer for the given model name via Huggingface."""
@@ -31,8 +32,11 @@ def get_tokenizer(
                 f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
                 "tokenizer.")
     try:
-        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, *args,
-                                                  **kwargs)
+        tokenizer = AutoTokenizer.from_pretrained(
+            tokenizer_name,
+            trust_remote_code=trust_remote_code,
+            *args,
+            **kwargs)
     except TypeError as e:
         # The LLaMA tokenizer causes a protobuf error in some environments.
         err_msg = (
@@ -40,6 +44,19 @@ def get_tokenizer(
             f"model, use '{_FAST_LLAMA_TOKENIZER}' instead of the original "
             "tokenizer.")
         raise RuntimeError(err_msg) from e
+    except ValueError as e:
+        # If the error pertains to the tokenizer class not existing or not
+        # currently being imported, suggest using the --trust-remote-code flag.
+        if (e is not None and
+            ("does not exist or is not currently imported." in str(e)
+             or "requires you to execute the tokenizer file" in str(e))):
+            err_msg = (
+                "Failed to load the tokenizer. If the tokenizer is a custom "
+                "tokenizer not yet available in the HuggingFace transformers "
+                "library, consider using the --trust-remote-code flag.")
+            raise RuntimeError(err_msg) from e
+        else:
+            raise e
     if not isinstance(tokenizer, PreTrainedTokenizerFast):
         logger.warning(
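
The new except ValueError branch rewraps the two messages transformers emits for unregistered or remote-code tokenizers into an actionable hint, and re-raises anything else unchanged. A hypothetical demo of both paths (the repository name is made up; the module path is assumed from the file above):

from vllm.transformers_utils.tokenizer import get_tokenizer

try:
    # Custom tokenizer, remote code not trusted: transformers raises a
    # ValueError, which get_tokenizer rewraps with the flag suggestion.
    get_tokenizer("some-org/model-with-custom-tokenizer")
except RuntimeError as e:
    print(e)

# Opting in downloads and executes the repository's tokenizer code.
tokenizer = get_tokenizer("some-org/model-with-custom-tokenizer",
                          trust_remote_code=True)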