Add trust-remote-code flag to handle remote tokenizers (#364)
commit a945fcc2ae
parent be54f8e5c4
@@ -20,6 +20,8 @@ class ModelConfig:
         tokenizer: Name or path of the huggingface tokenizer to use.
         tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
             available, and "slow" will always use the slow tokenizer.
+        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
+            downloading the model and tokenizer.
         download_dir: Directory to download and load the weights, default to the
             default cache directory of huggingface.
         use_np_weights: Save a numpy copy of model weights for faster loading.
@@ -36,6 +38,7 @@ class ModelConfig:
         model: str,
         tokenizer: str,
         tokenizer_mode: str,
+        trust_remote_code: bool,
         download_dir: Optional[str],
         use_np_weights: bool,
         use_dummy_weights: bool,
@@ -45,6 +48,7 @@ class ModelConfig:
         self.model = model
         self.tokenizer = tokenizer
         self.tokenizer_mode = tokenizer_mode
+        self.trust_remote_code = trust_remote_code
         self.download_dir = download_dir
         self.use_np_weights = use_np_weights
         self.use_dummy_weights = use_dummy_weights
@@ -13,6 +13,7 @@ class EngineArgs:
     model: str
     tokenizer: Optional[str] = None
     tokenizer_mode: str = 'auto'
+    trust_remote_code: bool = False
     download_dir: Optional[str] = None
     use_np_weights: bool = False
     use_dummy_weights: bool = False
@@ -55,6 +56,9 @@ class EngineArgs:
                             help='tokenizer mode. "auto" will use the fast '
                                  'tokenizer if available, and "slow" will '
                                  'always use the slow tokenizer.')
+        parser.add_argument('--trust-remote-code',
+                            action='store_true',
+                            help='trust remote code from huggingface')
         parser.add_argument('--download-dir',
                             type=str,
                             default=EngineArgs.download_dir,
@@ -141,9 +145,10 @@ class EngineArgs:
     ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
         # Initialize the configs.
         model_config = ModelConfig(self.model, self.tokenizer,
-                                   self.tokenizer_mode, self.download_dir,
-                                   self.use_np_weights, self.use_dummy_weights,
-                                   self.dtype, self.seed)
+                                   self.tokenizer_mode, self.trust_remote_code,
+                                   self.download_dir, self.use_np_weights,
+                                   self.use_dummy_weights, self.dtype,
+                                   self.seed)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
                                    self.swap_space)
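
With trust_remote_code now an EngineArgs field and a registered --trust-remote-code flag, anything built on EngineArgs can pick it up from argv. A minimal sketch of that flow, assuming the surrounding EngineArgs helpers add_cli_args and from_cli_args (their names are not visible in this diff) and an illustrative model name:

    import argparse
    from vllm.engine.arg_utils import EngineArgs

    parser = argparse.ArgumentParser()
    parser = EngineArgs.add_cli_args(parser)
    # Illustrative argv; substitute a model that actually ships remote code.
    args = parser.parse_args(["--model", "mosaicml/mpt-7b", "--trust-remote-code"])
    engine_args = EngineArgs.from_cli_args(args)
    assert engine_args.trust_remote_code is True
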
@@ -62,6 +62,7 @@ class LLMEngine:
             f"model={model_config.model!r}, "
             f"tokenizer={model_config.tokenizer!r}, "
             f"tokenizer_mode={model_config.tokenizer_mode}, "
+            f"trust_remote_code={model_config.trust_remote_code}, "
             f"dtype={model_config.dtype}, "
             f"use_dummy_weights={model_config.use_dummy_weights}, "
             f"download_dir={model_config.download_dir!r}, "
@@ -78,7 +79,9 @@ class LLMEngine:
         self._verify_args()
 
         self.tokenizer = get_tokenizer(
-            model_config.tokenizer, tokenizer_mode=model_config.tokenizer_mode)
+            model_config.tokenizer,
+            tokenizer_mode=model_config.tokenizer_mode,
+            trust_remote_code=model_config.trust_remote_code)
         self.seq_counter = Counter()
 
         # Create the parallel GPU workers.
@@ -28,6 +28,8 @@ class LLM:
         tokenizer: The name or path of a HuggingFace Transformers tokenizer.
         tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
             if available, and "slow" will always use the slow tokenizer.
+        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
+            downloading the model and tokenizer.
         tensor_parallel_size: The number of GPUs to use for distributed
             execution with tensor parallelism.
         dtype: The data type for the model weights and activations. Currently,
@@ -43,6 +45,7 @@ class LLM:
         model: str,
         tokenizer: Optional[str] = None,
         tokenizer_mode: str = "auto",
+        trust_remote_code: bool = False,
         tensor_parallel_size: int = 1,
         dtype: str = "auto",
         seed: int = 0,
@@ -54,6 +57,7 @@ class LLM:
             model=model,
             tokenizer=tokenizer,
             tokenizer_mode=tokenizer_mode,
+            trust_remote_code=trust_remote_code,
             tensor_parallel_size=tensor_parallel_size,
             dtype=dtype,
             seed=seed,
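
On the offline LLM entrypoint the flag defaults to False and is forwarded into EngineArgs, so a checkpoint whose modeling or tokenizer code lives in its HuggingFace repo loads in one call. A minimal usage sketch; the model name is illustrative, not part of this diff:

    from vllm import LLM

    # MPT-style checkpoints define their model class inside the repo,
    # so loading one without trust_remote_code=True would fail.
    llm = LLM(model="mosaicml/mpt-7b", trust_remote_code=True)
    outputs = llm.generate(["Hello, my name is"])
    print(outputs[0].outputs[0].text)
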
@@ -15,6 +15,7 @@ def get_tokenizer(
     tokenizer_name: str,
     *args,
     tokenizer_mode: str = "auto",
+    trust_remote_code: bool = False,
     **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
     """Gets a tokenizer for the given model name via Huggingface."""
@@ -31,8 +32,11 @@ def get_tokenizer(
             f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
             "tokenizer.")
     try:
-        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, *args,
-                                                  **kwargs)
+        tokenizer = AutoTokenizer.from_pretrained(
+            tokenizer_name,
+            trust_remote_code=trust_remote_code,
+            *args,
+            **kwargs)
     except TypeError as e:
         # The LLaMA tokenizer causes a protobuf error in some environments.
         err_msg = (
@@ -40,6 +44,19 @@ def get_tokenizer(
             f"model, use '{_FAST_LLAMA_TOKENIZER}' instead of the original "
             "tokenizer.")
         raise RuntimeError(err_msg) from e
+    except ValueError as e:
+        # If the error pertains to the tokenizer class not existing or not
+        # currently being imported, suggest using the --trust-remote-code flag.
+        if (e is not None and
+                ("does not exist or is not currently imported." in str(e)
+                 or "requires you to execute the tokenizer file" in str(e))):
+            err_msg = (
+                "Failed to load the tokenizer. If the tokenizer is a custom "
+                "tokenizer not yet available in the HuggingFace transformers "
+                "library, consider using the --trust-remote-code flag.")
+            raise RuntimeError(err_msg) from e
+        else:
+            raise e
 
     if not isinstance(tokenizer, PreTrainedTokenizerFast):
         logger.warning(
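
The net effect at the call site: a tokenizer class that only exists as remote code now fails with a RuntimeError naming the fix, instead of an opaque transformers ValueError. A sketch of both paths, assuming the module path vllm.transformers_utils.tokenizer and a hypothetical repo name:

    from vllm.transformers_utils.tokenizer import get_tokenizer

    # "some-org/custom-model" is a placeholder for a repo whose tokenizer
    # class is defined by the repo itself rather than by transformers.
    try:
        tokenizer = get_tokenizer("some-org/custom-model")
    except RuntimeError as e:
        print(e)  # suggests passing the --trust-remote-code flag

    # Opting in downloads and executes the repo's tokenizer code.
    tokenizer = get_tokenizer("some-org/custom-model", trust_remote_code=True)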