From ddfdf470ae721d5be668af97d5a2b5d40ce4c15c Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Sat, 8 Jul 2023 15:24:17 -0700
Subject: [PATCH] Add trust_remote_code arg to get_config (#405)

---
 vllm/config.py                       |  2 +-
 vllm/transformers_utils/config.py    | 17 +++++++++++++++--
 vllm/transformers_utils/tokenizer.py |  7 ++++---
 3 files changed, 20 insertions(+), 6 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index e00ba8b12060..c34bf2953634 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -54,7 +54,7 @@ class ModelConfig:
         self.use_dummy_weights = use_dummy_weights
         self.seed = seed
 
-        self.hf_config = get_config(model)
+        self.hf_config = get_config(model, trust_remote_code)
         self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
         self._verify_tokenizer_mode()
 
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index 866b23bff098..aeb4a6aacb53 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -7,8 +7,21 @@ _CONFIG_REGISTRY = {
 }
 
 
-def get_config(model: str) -> PretrainedConfig:
-    config = AutoConfig.from_pretrained(model, trust_remote_code=True)
+def get_config(model: str, trust_remote_code: bool) -> PretrainedConfig:
+    try:
+        config = AutoConfig.from_pretrained(
+            model, trust_remote_code=trust_remote_code)
+    except ValueError as e:
+        if (not trust_remote_code and
+                "requires you to execute the configuration file" in str(e)):
+            err_msg = (
+                "Failed to load the model config. If the model is a custom "
+                "model not yet available in the HuggingFace transformers "
+                "library, consider setting `trust_remote_code=True` in LLM "
+                "or using the `--trust-remote-code` flag in the CLI.")
+            raise RuntimeError(err_msg) from e
+        else:
+            raise e
     if config.model_type in _CONFIG_REGISTRY:
         config_class = _CONFIG_REGISTRY[config.model_type]
         config = config_class.from_pretrained(model)
diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py
index c38e6cc1b014..116a109a4541 100644
--- a/vllm/transformers_utils/tokenizer.py
+++ b/vllm/transformers_utils/tokenizer.py
@@ -34,8 +34,8 @@ def get_tokenizer(
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_name,
-            trust_remote_code=trust_remote_code,
             *args,
+            trust_remote_code=trust_remote_code,
             **kwargs)
     except TypeError as e:
         # The LLaMA tokenizer causes a protobuf error in some environments.
@@ -47,13 +47,14 @@ def get_tokenizer(
     except ValueError as e:
         # If the error pertains to the tokenizer class not existing or not
         # currently being imported, suggest using the --trust-remote-code flag.
-        if (e is not None and
+        if (not trust_remote_code and
             ("does not exist or is not currently imported." in str(e) or
              "requires you to execute the tokenizer file" in str(e))):
             err_msg = (
                 "Failed to load the tokenizer. If the tokenizer is a custom "
                 "tokenizer not yet available in the HuggingFace transformers "
-                "library, consider using the --trust-remote-code flag.")
+                "library, consider setting `trust_remote_code=True` in LLM "
+                "or using the `--trust-remote-code` flag in the CLI.")
             raise RuntimeError(err_msg) from e
         else:
             raise e
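
Usage note (not part of the patch): a minimal sketch of how the new get_config
argument behaves, assuming a transformers version in which the illustrative
model id "mosaicml/mpt-7b" still requires executing config code from the Hub.

    from vllm.transformers_utils.config import get_config

    # Without opting in, a custom-code model now fails fast with an
    # actionable RuntimeError instead of implicitly trusting remote code:
    try:
        get_config("mosaicml/mpt-7b", trust_remote_code=False)
    except RuntimeError as e:
        print(e)  # message points at trust_remote_code=True / --trust-remote-code

    # Opting in explicitly restores the previous behavior:
    config = get_config("mosaicml/mpt-7b", trust_remote_code=True)
    print(config.model_type)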
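
The new error messages point users one level up; a sketch of the two entry
points they refer to (the model id is again only illustrative):

    from vllm import LLM

    # Python API: trust_remote_code is forwarded through ModelConfig
    # into get_config (and likewise into get_tokenizer).
    llm = LLM(model="mosaicml/mpt-7b", trust_remote_code=True)

and on the command line:

    python -m vllm.entrypoints.api_server --model mosaicml/mpt-7b --trust-remote-code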