Add trust_remote_code arg to get_config (#405)

This commit is contained in:
Woosuk Kwon 2023-07-08 15:24:17 -07:00 committed by GitHub
parent b6fbb9a565
commit ddfdf470ae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 20 additions and 6 deletions

View File

@ -54,7 +54,7 @@ class ModelConfig:
self.use_dummy_weights = use_dummy_weights
self.seed = seed
self.hf_config = get_config(model)
self.hf_config = get_config(model, trust_remote_code)
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
self._verify_tokenizer_mode()

View File

@ -7,8 +7,21 @@ _CONFIG_REGISTRY = {
}
def get_config(model: str) -> PretrainedConfig:
config = AutoConfig.from_pretrained(model, trust_remote_code=True)
def get_config(model: str, trust_remote_code: bool) -> PretrainedConfig:
try:
config = AutoConfig.from_pretrained(
model, trust_remote_code=trust_remote_code)
except ValueError as e:
if (not trust_remote_code and
"requires you to execute the configuration file" in str(e)):
err_msg = (
"Failed to load the model config. If the model is a custom "
"model not yet available in the HuggingFace transformers "
"library, consider setting `trust_remote_code=True` in LLM "
"or using the `--trust-remote-code` flag in the CLI.")
raise RuntimeError(err_msg) from e
else:
raise e
if config.model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[config.model_type]
config = config_class.from_pretrained(model)

View File

@ -34,8 +34,8 @@ def get_tokenizer(
try:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
trust_remote_code=trust_remote_code,
*args,
trust_remote_code=trust_remote_code,
**kwargs)
except TypeError as e:
# The LLaMA tokenizer causes a protobuf error in some environments.
@ -47,13 +47,14 @@ def get_tokenizer(
except ValueError as e:
# If the error pertains to the tokenizer class not existing or not
# currently being imported, suggest using the --trust-remote-code flag.
if (e is not None and
if (not trust_remote_code and
("does not exist or is not currently imported." in str(e)
or "requires you to execute the tokenizer file" in str(e))):
err_msg = (
"Failed to load the tokenizer. If the tokenizer is a custom "
"tokenizer not yet available in the HuggingFace transformers "
"library, consider using the --trust-remote-code flag.")
"library, consider setting `trust_remote_code=True` in LLM "
"or using the `--trust-remote-code` flag in the CLI.")
raise RuntimeError(err_msg) from e
else:
raise e