mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-21 01:57:03 +08:00
234 lines
7.7 KiB
Python
234 lines
7.7 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
import importlib.util
|
|
from collections.abc import Callable
|
|
from functools import lru_cache
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, TypeVar, overload
|
|
|
|
import huggingface_hub
|
|
from typing_extensions import assert_never
|
|
|
|
import vllm.envs as envs
|
|
from vllm.logger import init_logger
|
|
from vllm.transformers_utils.gguf_utils import (
|
|
check_gguf_file,
|
|
get_gguf_file_path_from_hf,
|
|
is_gguf,
|
|
is_remote_gguf,
|
|
split_remote_gguf,
|
|
)
|
|
from vllm.transformers_utils.repo_utils import list_filtered_repo_files
|
|
from vllm.utils.import_utils import resolve_obj_by_qualname
|
|
|
|
from .protocol import TokenizerLike
|
|
|
|
if TYPE_CHECKING:
|
|
from vllm.config import ModelConfig
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
_T = TypeVar("_T", bound=type[TokenizerLike])
|
|
|
|
|
|
class TokenizerRegistry:
|
|
# Tokenizer name -> tokenizer_cls or (tokenizer module, tokenizer class)
|
|
REGISTRY: dict[str, type[TokenizerLike] | tuple[str, str]] = {}
|
|
|
|
# In-tree tokenizers
|
|
@staticmethod
|
|
@overload
|
|
def register(tokenizer_mode: str) -> Callable[[_T], _T]: ...
|
|
|
|
# OOT tokenizers
|
|
@staticmethod
|
|
@overload
|
|
def register(tokenizer_mode: str, module: str, class_name: str) -> None: ...
|
|
|
|
@staticmethod
|
|
def register(
|
|
tokenizer_mode: str,
|
|
module: str | None = None,
|
|
class_name: str | None = None,
|
|
) -> Callable[[_T], _T] | None:
|
|
# In-tree tokenizers
|
|
if module is None or class_name is None:
|
|
|
|
def wrapper(tokenizer_cls: _T) -> _T:
|
|
assert tokenizer_mode not in TokenizerRegistry.REGISTRY
|
|
TokenizerRegistry.REGISTRY[tokenizer_mode] = tokenizer_cls
|
|
|
|
return tokenizer_cls
|
|
|
|
return wrapper
|
|
|
|
# OOT tokenizers
|
|
if tokenizer_mode in TokenizerRegistry.REGISTRY:
|
|
logger.warning(
|
|
"%s.%s is already registered for tokenizer_mode=%r. "
|
|
"It is overwritten by the new one.",
|
|
module,
|
|
class_name,
|
|
tokenizer_mode,
|
|
)
|
|
|
|
TokenizerRegistry.REGISTRY[tokenizer_mode] = (module, class_name)
|
|
|
|
return None
|
|
|
|
@staticmethod
|
|
def get_tokenizer(tokenizer_mode: str, *args, **kwargs) -> "TokenizerLike":
|
|
if tokenizer_mode not in TokenizerRegistry.REGISTRY:
|
|
raise ValueError(f"No tokenizer registered for {tokenizer_mode=!r}.")
|
|
|
|
item = TokenizerRegistry.REGISTRY[tokenizer_mode]
|
|
if isinstance(item, type):
|
|
return item.from_pretrained(*args, **kwargs)
|
|
|
|
module, class_name = item
|
|
logger.debug_once(f"Loading {class_name} for {tokenizer_mode=!r}")
|
|
|
|
class_ = resolve_obj_by_qualname(f"{module}.{class_name}")
|
|
return class_.from_pretrained(*args, **kwargs)
|
|
|
|
|
|
def get_tokenizer(
|
|
tokenizer_name: str | Path,
|
|
*args,
|
|
tokenizer_mode: str = "auto",
|
|
trust_remote_code: bool = False,
|
|
revision: str | None = None,
|
|
download_dir: str | None = None,
|
|
**kwargs,
|
|
) -> TokenizerLike:
|
|
"""Gets a tokenizer for the given model name via HuggingFace or ModelScope."""
|
|
if envs.VLLM_USE_MODELSCOPE:
|
|
# download model from ModelScope hub,
|
|
# lazy import so that modelscope is not required for normal use.
|
|
from modelscope.hub.snapshot_download import snapshot_download
|
|
|
|
# avoid circular import
|
|
from vllm.model_executor.model_loader.weight_utils import get_lock
|
|
|
|
# Only set the tokenizer here, model will be downloaded on the workers.
|
|
if not Path(tokenizer_name).exists():
|
|
# Use file lock to prevent multiple processes from
|
|
# downloading the same file at the same time.
|
|
with get_lock(tokenizer_name, download_dir):
|
|
tokenizer_path = snapshot_download(
|
|
model_id=str(tokenizer_name),
|
|
cache_dir=download_dir,
|
|
revision=revision,
|
|
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
|
# Ignore weights - we only need the tokenizer.
|
|
ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"],
|
|
)
|
|
tokenizer_name = tokenizer_path
|
|
|
|
if tokenizer_mode == "slow":
|
|
if kwargs.get("use_fast", False):
|
|
raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
|
|
|
|
tokenizer_mode = "hf"
|
|
kwargs["use_fast"] = False
|
|
|
|
if "truncation_side" not in kwargs:
|
|
kwargs["truncation_side"] = "left"
|
|
|
|
# Separate model folder from file path for GGUF models
|
|
if is_gguf(tokenizer_name):
|
|
if check_gguf_file(tokenizer_name):
|
|
kwargs["gguf_file"] = Path(tokenizer_name).name
|
|
tokenizer_name = Path(tokenizer_name).parent
|
|
elif is_remote_gguf(tokenizer_name):
|
|
tokenizer_name, quant_type = split_remote_gguf(tokenizer_name)
|
|
# Get the HuggingFace Hub path for the GGUF file
|
|
gguf_file = get_gguf_file_path_from_hf(
|
|
tokenizer_name,
|
|
quant_type,
|
|
revision=revision,
|
|
)
|
|
kwargs["gguf_file"] = gguf_file
|
|
|
|
# Try to use official Mistral tokenizer if possible
|
|
if tokenizer_mode == "auto" and importlib.util.find_spec("mistral_common"):
|
|
allow_patterns = ["tekken.json", "tokenizer.model.v*"]
|
|
files_list = list_filtered_repo_files(
|
|
model_name_or_path=str(tokenizer_name),
|
|
allow_patterns=allow_patterns,
|
|
revision=revision,
|
|
)
|
|
if len(files_list) > 0:
|
|
tokenizer_mode = "mistral"
|
|
|
|
# Fallback to HF tokenizer
|
|
if tokenizer_mode == "auto":
|
|
tokenizer_mode = "hf"
|
|
|
|
tokenizer_args = (tokenizer_name, *args)
|
|
tokenizer_kwargs = dict(
|
|
trust_remote_code=trust_remote_code,
|
|
revision=revision,
|
|
download_dir=download_dir,
|
|
**kwargs,
|
|
)
|
|
|
|
if tokenizer_mode == "custom":
|
|
logger.warning_once(
|
|
"TokenizerRegistry now uses `tokenizer_mode` as the registry key "
|
|
"instead of `tokenizer_name`. "
|
|
"Please update the definition of `.from_pretrained` in "
|
|
"your custom tokenizer to accept `args=%s`, `kwargs=%s`. "
|
|
"Then, you can pass `tokenizer_mode=%r` instead of "
|
|
"`tokenizer_mode='custom'` when initializing vLLM.",
|
|
tokenizer_args,
|
|
str(tokenizer_kwargs),
|
|
tokenizer_mode,
|
|
)
|
|
|
|
tokenizer_mode = str(tokenizer_name)
|
|
|
|
tokenizer = TokenizerRegistry.get_tokenizer(
|
|
tokenizer_mode,
|
|
*tokenizer_args,
|
|
**tokenizer_kwargs,
|
|
)
|
|
if not tokenizer.is_fast:
|
|
logger.warning(
|
|
"Using a slow tokenizer. This might cause a significant "
|
|
"slowdown. Consider using a fast tokenizer instead."
|
|
)
|
|
|
|
return tokenizer
|
|
|
|
|
|
cached_get_tokenizer = lru_cache(get_tokenizer)
|
|
|
|
|
|
def cached_tokenizer_from_config(model_config: "ModelConfig", **kwargs):
|
|
return cached_get_tokenizer(
|
|
model_config.tokenizer,
|
|
tokenizer_mode=model_config.tokenizer_mode,
|
|
revision=model_config.tokenizer_revision,
|
|
trust_remote_code=model_config.trust_remote_code,
|
|
**kwargs,
|
|
)
|
|
|
|
|
|
def init_tokenizer_from_config(model_config: "ModelConfig"):
|
|
runner_type = model_config.runner_type
|
|
if runner_type == "generate" or runner_type == "draft":
|
|
truncation_side = "left"
|
|
elif runner_type == "pooling":
|
|
truncation_side = "right"
|
|
else:
|
|
assert_never(runner_type)
|
|
|
|
return get_tokenizer(
|
|
model_config.tokenizer,
|
|
tokenizer_mode=model_config.tokenizer_mode,
|
|
trust_remote_code=model_config.trust_remote_code,
|
|
revision=model_config.tokenizer_revision,
|
|
truncation_side=truncation_side,
|
|
)
|