Add trust_remote_code arg to get_config (#405)

This commit is contained in:
Woosuk Kwon 2023-07-08 15:24:17 -07:00 committed by GitHub
parent b6fbb9a565
commit ddfdf470ae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 20 additions and 6 deletions

View File

@@ -54,7 +54,7 @@ class ModelConfig:
         self.use_dummy_weights = use_dummy_weights
         self.seed = seed
-        self.hf_config = get_config(model)
+        self.hf_config = get_config(model, trust_remote_code)
         self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
         self._verify_tokenizer_mode()

View File

@@ -7,8 +7,21 @@ _CONFIG_REGISTRY = {
 }
-def get_config(model: str) -> PretrainedConfig:
-    config = AutoConfig.from_pretrained(model, trust_remote_code=True)
+def get_config(model: str, trust_remote_code: bool) -> PretrainedConfig:
+    try:
+        config = AutoConfig.from_pretrained(
+            model, trust_remote_code=trust_remote_code)
+    except ValueError as e:
+        if (not trust_remote_code and
+                "requires you to execute the configuration file" in str(e)):
+            err_msg = (
+                "Failed to load the model config. If the model is a custom "
+                "model not yet available in the HuggingFace transformers "
+                "library, consider setting `trust_remote_code=True` in LLM "
+                "or using the `--trust-remote-code` flag in the CLI.")
+            raise RuntimeError(err_msg) from e
+        else:
+            raise e
     if config.model_type in _CONFIG_REGISTRY:
         config_class = _CONFIG_REGISTRY[config.model_type]
         config = config_class.from_pretrained(model)

View File

@@ -34,8 +34,8 @@ def get_tokenizer(
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_name,
-            trust_remote_code=trust_remote_code,
             *args,
+            trust_remote_code=trust_remote_code,
             **kwargs)
except TypeError as e: except TypeError as e:
# The LLaMA tokenizer causes a protobuf error in some environments. # The LLaMA tokenizer causes a protobuf error in some environments.
@@ -47,13 +47,14 @@ def get_tokenizer(
     except ValueError as e:
         # If the error pertains to the tokenizer class not existing or not
         # currently being imported, suggest using the --trust-remote-code flag.
-        if (e is not None and
+        if (not trust_remote_code and
                 ("does not exist or is not currently imported." in str(e)
                  or "requires you to execute the tokenizer file" in str(e))):
             err_msg = (
                 "Failed to load the tokenizer. If the tokenizer is a custom "
                 "tokenizer not yet available in the HuggingFace transformers "
-                "library, consider using the --trust-remote-code flag.")
+                "library, consider setting `trust_remote_code=True` in LLM "
+                "or using the `--trust-remote-code` flag in the CLI.")
             raise RuntimeError(err_msg) from e
         else:
             raise e