Add trust_remote_code arg to get_config (#405)
parent b6fbb9a565
commit ddfdf470ae
@@ -54,7 +54,7 @@ class ModelConfig:
         self.use_dummy_weights = use_dummy_weights
         self.seed = seed
 
-        self.hf_config = get_config(model)
+        self.hf_config = get_config(model, trust_remote_code)
         self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
         self._verify_tokenizer_mode()
 
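For context, a minimal usage sketch of the flag this hunk threads into ModelConfig; the model name is a placeholder, and the end-to-end plumbing from the LLM entrypoint down to get_config()/get_tokenizer() is assumed from the error messages introduced below, not shown in this diff.

# Minimal usage sketch (not part of the diff): loading a model whose config
# and tokenizer ship as custom code in its Hugging Face repo. The model name
# is a placeholder, and the forwarding of trust_remote_code from LLM down to
# get_config()/get_tokenizer() is assumed, as the new error messages suggest.
from vllm import LLM

llm = LLM(model="some-org/custom-model", trust_remote_code=True)
outputs = llm.generate(["Hello, my name is"])
print(outputs[0].outputs[0].text)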
@@ -7,8 +7,21 @@ _CONFIG_REGISTRY = {
 }
 
 
-def get_config(model: str) -> PretrainedConfig:
-    config = AutoConfig.from_pretrained(model, trust_remote_code=True)
+def get_config(model: str, trust_remote_code: bool) -> PretrainedConfig:
+    try:
+        config = AutoConfig.from_pretrained(
+            model, trust_remote_code=trust_remote_code)
+    except ValueError as e:
+        if (not trust_remote_code and
+                "requires you to execute the configuration file" in str(e)):
+            err_msg = (
+                "Failed to load the model config. If the model is a custom "
+                "model not yet available in the HuggingFace transformers "
+                "library, consider setting `trust_remote_code=True` in LLM "
+                "or using the `--trust-remote-code` flag in the CLI.")
+            raise RuntimeError(err_msg) from e
+        else:
+            raise e
     if config.model_type in _CONFIG_REGISTRY:
         config_class = _CONFIG_REGISTRY[config.model_type]
         config = config_class.from_pretrained(model)
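A hedged sketch of how the reworked get_config is expected to behave; the import path (vllm.transformers_utils.config) and the model name are assumptions for illustration, not part of the diff.

# Sketch only: exercising both branches of the new error handling.
# The import path and model name below are assumptions for illustration.
from vllm.transformers_utils.config import get_config

# Without the flag, the ValueError raised by AutoConfig for a model with
# custom configuration code is re-raised as a RuntimeError with guidance.
try:
    get_config("some-org/custom-model", trust_remote_code=False)
except RuntimeError as err:
    print(err)

# With the flag, the remote configuration code is downloaded and executed.
config = get_config("some-org/custom-model", trust_remote_code=True)
print(config.model_type)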
@@ -34,8 +34,8 @@ def get_tokenizer(
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_name,
-            trust_remote_code=trust_remote_code,
             *args,
+            trust_remote_code=trust_remote_code,
             **kwargs)
     except TypeError as e:
         # The LLaMA tokenizer causes a protobuf error in some environments.
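The hunk above only reorders the call so that *args precedes the keyword argument. One plausible reason (an assumption; the diff gives no rationale) is that a keyword placed before *args can collide with a positional argument bound to the same parameter:

# Standalone illustration; the collision motivation is an assumption.
def load(name, revision="main", **kwargs):
    return name, revision, kwargs

extra = ("v1.0",)
print(load("model", *extra, trust_remote_code=True))
# -> ('model', 'v1.0', {'trust_remote_code': True})

# load("model", revision="v2.0", *extra)
# -> TypeError: load() got multiple values for argument 'revision'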
@@ -47,13 +47,14 @@ def get_tokenizer(
     except ValueError as e:
         # If the error pertains to the tokenizer class not existing or not
         # currently being imported, suggest using the --trust-remote-code flag.
-        if (e is not None and
+        if (not trust_remote_code and
             ("does not exist or is not currently imported." in str(e)
              or "requires you to execute the tokenizer file" in str(e))):
             err_msg = (
                 "Failed to load the tokenizer. If the tokenizer is a custom "
                 "tokenizer not yet available in the HuggingFace transformers "
-                "library, consider using the --trust-remote-code flag.")
+                "library, consider setting `trust_remote_code=True` in LLM "
+                "or using the `--trust-remote-code` flag in the CLI.")
             raise RuntimeError(err_msg) from e
         else:
             raise e
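A hedged usage sketch for the tokenizer path, mirroring get_config; the import path (vllm.transformers_utils.tokenizer) and the model name are assumptions for illustration.

# Sketch only: import path and model name are assumptions for illustration.
from vllm.transformers_utils.tokenizer import get_tokenizer

# Without the flag, a tokenizer that ships custom code now triggers the
# RuntimeError carrying the updated guidance message.
try:
    get_tokenizer("some-org/custom-model", trust_remote_code=False)
except RuntimeError as err:
    print(err)

# With trust_remote_code=True (or --trust-remote-code on the CLI), the
# remote tokenizer code is downloaded and executed.
tokenizer = get_tokenizer("some-org/custom-model", trust_remote_code=True)
print(type(tokenizer).__name__)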