diff --git a/tests/test_inputs.py b/tests/test_inputs.py index c4339827de8b6..8351af2528e4b 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -7,7 +7,7 @@ from vllm.config import ModelConfig from vllm.inputs import zip_enc_dec_prompts from vllm.inputs.parse import parse_raw_prompts from vllm.inputs.preprocess import InputPreprocessor -from vllm.tokenizers import init_tokenizer_from_config +from vllm.tokenizers import cached_tokenizer_from_config pytestmark = pytest.mark.cpu_test @@ -108,7 +108,7 @@ def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs): ) def test_preprocessor_always_mm_code_path(model_id, prompt): model_config = ModelConfig(model=model_id) - tokenizer = init_tokenizer_from_config(model_config) + tokenizer = cached_tokenizer_from_config(model_config) input_preprocessor = InputPreprocessor(model_config, tokenizer) # HF processor adds sep token diff --git a/tests/tokenizers_/test_basic.py b/tests/tokenizers_/test_basic.py index b152227a5a50f..0510261eacde7 100644 --- a/tests/tokenizers_/test_basic.py +++ b/tests/tokenizers_/test_basic.py @@ -3,38 +3,39 @@ from typing import _get_protocol_attrs # type: ignore import pytest -from transformers import PreTrainedTokenizerBase +from transformers import ( + PreTrainedTokenizer, + PreTrainedTokenizerBase, + PreTrainedTokenizerFast, +) from vllm.tokenizers import TokenizerLike, get_tokenizer +from vllm.tokenizers.mistral import MistralTokenizer def _get_missing_attrs(obj: object, target: type): return [k for k in _get_protocol_attrs(target) if not hasattr(obj, k)] +def _assert_tokenizer_like(tokenizer: object): + missing_attrs = _get_missing_attrs(tokenizer, TokenizerLike) + assert not missing_attrs, f"Missing attrs: {missing_attrs}" + + def test_tokenizer_like_protocol(): - assert not ( - missing_attrs := _get_missing_attrs( - get_tokenizer("gpt2", use_fast=False), - TokenizerLike, - ) - ), f"Missing attrs: {missing_attrs}" + tokenizer = get_tokenizer("gpt2", use_fast=False) + assert isinstance(tokenizer, PreTrainedTokenizer) + _assert_tokenizer_like(tokenizer) - assert not ( - missing_attrs := _get_missing_attrs( - get_tokenizer("gpt2", use_fast=True), - TokenizerLike, - ) - ), f"Missing attrs: {missing_attrs}" + tokenizer = get_tokenizer("gpt2", use_fast=True) + assert isinstance(tokenizer, PreTrainedTokenizerFast) + _assert_tokenizer_like(tokenizer) - assert not ( - missing_attrs := _get_missing_attrs( - get_tokenizer( - "mistralai/Mistral-7B-Instruct-v0.3", tokenizer_mode="mistral" - ), - TokenizerLike, - ) - ), f"Missing attrs: {missing_attrs}" + tokenizer = get_tokenizer( + "mistralai/Mistral-7B-Instruct-v0.3", tokenizer_mode="mistral" + ) + assert isinstance(tokenizer, MistralTokenizer) + _assert_tokenizer_like(tokenizer) @pytest.mark.parametrize("tokenizer_name", ["facebook/opt-125m", "gpt2"]) diff --git a/tests/tokenizers_/test_registry.py b/tests/tokenizers_/test_registry.py index 7e795350d64c8..546f38b078dde 100644 --- a/tests/tokenizers_/test_registry.py +++ b/tests/tokenizers_/test_registry.py @@ -2,7 +2,14 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from pathlib import Path -from vllm.tokenizers import TokenizerLike, TokenizerRegistry, get_tokenizer +import pytest + +from vllm.tokenizers import TokenizerLike +from vllm.tokenizers.registry import ( + TokenizerRegistry, + get_tokenizer, + resolve_tokenizer_args, +) class TestTokenizer(TokenizerLike): @@ -40,10 +47,22 @@ class TestTokenizer(TokenizerLike): return True +@pytest.mark.parametrize("runner_type", ["generate", "pooling"]) +def test_resolve_tokenizer_args_idempotent(runner_type): + tokenizer_mode, tokenizer_name, args, kwargs = resolve_tokenizer_args( + "facebook/opt-125m", + runner_type=runner_type, + ) + + assert (tokenizer_mode, tokenizer_name, args, kwargs) == resolve_tokenizer_args( + tokenizer_name, *args, **kwargs + ) + + def test_customized_tokenizer(): TokenizerRegistry.register("test_tokenizer", __name__, TestTokenizer.__name__) - tokenizer = TokenizerRegistry.get_tokenizer("test_tokenizer", "abc") + tokenizer = TokenizerRegistry.load_tokenizer("test_tokenizer", "abc") assert isinstance(tokenizer, TestTokenizer) assert tokenizer.path_or_repo_id == "abc" assert tokenizer.bos_token_id == 0 diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 6a8dfe3cd9e38..8485022024a4f 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -50,7 +50,6 @@ from vllm.model_executor.models import SupportsMultiModal from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector from vllm.tokenizers import TokenizerLike -from vllm.tokenizers.mistral import MistralTokenizer from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path from vllm.transformers_utils.processor import cached_get_processor from vllm.utils import random_uuid @@ -60,6 +59,8 @@ from vllm.utils.import_utils import LazyLoader if TYPE_CHECKING: import torch + + from vllm.tokenizers.mistral import MistralTokenizer else: torch = LazyLoader("torch", globals(), "torch") @@ -1832,7 +1833,7 @@ def apply_hf_chat_template( def apply_mistral_chat_template( - tokenizer: MistralTokenizer, + tokenizer: "MistralTokenizer", messages: list[ChatCompletionMessageParam], chat_template: str | None, tools: list[dict[str, Any]] | None, diff --git a/vllm/tokenizers/__init__.py b/vllm/tokenizers/__init__.py index 67a6d7c8eb3d9..31e74b1a16e20 100644 --- a/vllm/tokenizers/__init__.py +++ b/vllm/tokenizers/__init__.py @@ -1,9 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from .deepseekv32 import DeepseekV32Tokenizer -from .hf import HfTokenizer -from .mistral import MistralTokenizer from .protocol import TokenizerLike from .registry import ( TokenizerRegistry, @@ -15,12 +12,9 @@ from .registry import ( __all__ = [ "TokenizerLike", - "HfTokenizer", - "MistralTokenizer", "TokenizerRegistry", "cached_get_tokenizer", "get_tokenizer", "cached_tokenizer_from_config", "init_tokenizer_from_config", - "DeepseekV32Tokenizer", ] diff --git a/vllm/tokenizers/deepseekv32.py b/vllm/tokenizers/deepseekv32.py index a7fa0f421725a..bf279a5cf67c5 100644 --- a/vllm/tokenizers/deepseekv32.py +++ b/vllm/tokenizers/deepseekv32.py @@ -2,24 +2,18 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from pathlib import Path +from typing import Any from transformers import BatchEncoding +from vllm.entrypoints.chat_utils import ChatCompletionMessageParam + from .deepseek_v32_encoding import encode_messages -from .hf import HfTokenizer, TokenizerLike -from .registry import TokenizerRegistry +from .hf import CachedHfTokenizer +from .protocol import TokenizerLike -@TokenizerRegistry.register("deepseek_v32") -class DeepseekV32Tokenizer(HfTokenizer): - def __init__(self, tokenizer: TokenizerLike): - self.tokenizer = tokenizer - self.name_or_path = ( - tokenizer.name_or_path if hasattr(tokenizer, "name_or_path") else "" - ) - self._added_vocab = self.tokenizer.get_added_vocab() - self._added_vocab_size = len(self._added_vocab) - +class DeepseekV32Tokenizer(CachedHfTokenizer): @classmethod def from_pretrained( cls, @@ -40,7 +34,21 @@ class DeepseekV32Tokenizer(HfTokenizer): ) return DeepseekV32Tokenizer(tokenizer) - def apply_chat_template(self, messages, tools=None, **kwargs): + def __init__(self, tokenizer: TokenizerLike) -> None: + super().__init__() + + self.tokenizer = tokenizer + self.name_or_path = getattr(tokenizer, "name_or_path", "") + + self._added_vocab = self.tokenizer.get_added_vocab() + self._added_vocab_size = len(self._added_vocab) + + def apply_chat_template( + self, + messages: list["ChatCompletionMessageParam"], + tools: list[dict[str, Any]] | None = None, + **kwargs, + ) -> str | list[int]: thinking = kwargs.get("thinking", False) thinking_mode = "thinking" if not thinking: @@ -49,13 +57,24 @@ class DeepseekV32Tokenizer(HfTokenizer): messages = conversation.copy() if tools is not None and len(tools) > 0: messages.insert(0, {"role": "system"}) - messages[0]["tools"] = tools + messages[0]["tools"] = tools # type: ignore[typeddict-unknown-key] # Historical reasoning content is dropped when a new user message is introduced drop_thinking = messages[-1]["role"] == "user" encode_config = dict(thinking_mode=thinking_mode, drop_thinking=drop_thinking) prompt_str = encode_messages(messages, **encode_config) # type: ignore + + if kwargs.get("tokenize", True): + tokenizer_kwargs = { + k: kwargs[k] for k in ("truncation", "max_length") if k in kwargs + } + return self.encode( + prompt_str, + add_special_tokens=False, + **tokenizer_kwargs, + ) + return prompt_str def num_special_tokens_to_add(self) -> int: diff --git a/vllm/tokenizers/hf.py b/vllm/tokenizers/hf.py index 3445073120387..a7b565dca5d8f 100644 --- a/vllm/tokenizers/hf.py +++ b/vllm/tokenizers/hf.py @@ -3,22 +3,18 @@ import contextlib import copy from pathlib import Path -from typing import TYPE_CHECKING +from typing import TypeAlias -from transformers import AutoTokenizer +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast from vllm.transformers_utils.config import get_sentence_transformer_tokenizer_config from .protocol import TokenizerLike -from .registry import TokenizerRegistry -if TYPE_CHECKING: - from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +HfTokenizer: TypeAlias = PreTrainedTokenizer | PreTrainedTokenizerFast -def get_cached_tokenizer( - tokenizer: "PreTrainedTokenizer | PreTrainedTokenizerFast", -) -> TokenizerLike: +def get_cached_tokenizer(tokenizer: HfTokenizer) -> HfTokenizer: """ By default, transformers will recompute multiple tokenizer properties each time they are called, leading to a significant slowdown. @@ -65,11 +61,10 @@ def get_cached_tokenizer( CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}" cached_tokenizer.__class__ = CachedTokenizer - return cached_tokenizer # type: ignore + return cached_tokenizer -@TokenizerRegistry.register("hf") -class HfTokenizer(TokenizerLike): +class CachedHfTokenizer(TokenizerLike): @classmethod def from_pretrained( cls, @@ -79,7 +74,7 @@ class HfTokenizer(TokenizerLike): revision: str | None = None, download_dir: str | None = None, **kwargs, - ) -> "TokenizerLike": + ) -> HfTokenizer: try: tokenizer = AutoTokenizer.from_pretrained( path_or_repo_id, diff --git a/vllm/tokenizers/mistral.py b/vllm/tokenizers/mistral.py index 1f44037dd55ec..534b0da484a5d 100644 --- a/vllm/tokenizers/mistral.py +++ b/vllm/tokenizers/mistral.py @@ -3,10 +3,11 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, cast +from vllm.entrypoints.chat_utils import ChatCompletionMessageParam +from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.logger import init_logger from .protocol import TokenizerLike -from .registry import TokenizerRegistry if TYPE_CHECKING: from mistral_common.protocol.instruct.request import ( @@ -15,9 +16,6 @@ if TYPE_CHECKING: from mistral_common.tokens.tokenizers.tekken import Tekkenizer from transformers import BatchEncoding - from vllm.entrypoints.chat_utils import ChatCompletionMessageParam - from vllm.entrypoints.openai.protocol import ChatCompletionRequest - try: # Transformers v5 from transformers.tokenization_mistral_common import MistralCommonBackend @@ -201,7 +199,6 @@ def _tekken_token_to_id(tokenizer: "Tekkenizer", t: str | bytes) -> int: return tokenizer.unk_id -@TokenizerRegistry.register("mistral") class MistralTokenizer(TokenizerLike): @classmethod def from_pretrained( diff --git a/vllm/tokenizers/protocol.py b/vllm/tokenizers/protocol.py index d6a3b0ba9b5f5..28754f9e10d00 100644 --- a/vllm/tokenizers/protocol.py +++ b/vllm/tokenizers/protocol.py @@ -97,7 +97,7 @@ class TokenizerLike(Protocol): messages: list["ChatCompletionMessageParam"], tools: list[dict[str, Any]] | None = None, **kwargs, - ) -> list[int]: + ) -> str | list[int]: raise NotImplementedError def convert_tokens_to_string(self, tokens: list[str]) -> str: diff --git a/vllm/tokenizers/registry.py b/vllm/tokenizers/registry.py index 1d44feeee500f..1296ce62ae693 100644 --- a/vllm/tokenizers/registry.py +++ b/vllm/tokenizers/registry.py @@ -1,13 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import importlib.util -from collections.abc import Callable +from dataclasses import dataclass, field from functools import lru_cache from pathlib import Path -from typing import TYPE_CHECKING, TypeVar, overload +from typing import TYPE_CHECKING import huggingface_hub -from typing_extensions import assert_never +from typing_extensions import TypeVar, assert_never, deprecated import vllm.envs as envs from vllm.logger import init_logger @@ -24,46 +24,25 @@ from vllm.utils.import_utils import resolve_obj_by_qualname from .protocol import TokenizerLike if TYPE_CHECKING: - from vllm.config import ModelConfig + from vllm.config.model import ModelConfig, RunnerType logger = init_logger(__name__) -_T = TypeVar("_T", bound=type[TokenizerLike]) + +_VLLM_TOKENIZERS = { + "deepseekv32": ("deepseekv32", "DeepseekV32Tokenizer"), + "hf": ("hf", "CachedHfTokenizer"), + "mistral": ("mistral", "MistralTokenizer"), +} -class TokenizerRegistry: - # Tokenizer name -> tokenizer_cls or (tokenizer module, tokenizer class) - REGISTRY: dict[str, type[TokenizerLike] | tuple[str, str]] = {} +@dataclass +class _TokenizerRegistry: + # Tokenizer mode -> (tokenizer module, tokenizer class) + tokenizers: dict[str, tuple[str, str]] = field(default_factory=dict) - # In-tree tokenizers - @staticmethod - @overload - def register(tokenizer_mode: str) -> Callable[[_T], _T]: ... - - # OOT tokenizers - @staticmethod - @overload - def register(tokenizer_mode: str, module: str, class_name: str) -> None: ... - - @staticmethod - def register( - tokenizer_mode: str, - module: str | None = None, - class_name: str | None = None, - ) -> Callable[[_T], _T] | None: - # In-tree tokenizers - if module is None or class_name is None: - - def wrapper(tokenizer_cls: _T) -> _T: - assert tokenizer_mode not in TokenizerRegistry.REGISTRY - TokenizerRegistry.REGISTRY[tokenizer_mode] = tokenizer_cls - - return tokenizer_cls - - return wrapper - - # OOT tokenizers - if tokenizer_mode in TokenizerRegistry.REGISTRY: + def register(self, tokenizer_mode: str, module: str, class_name: str) -> None: + if tokenizer_mode in self.tokenizers: logger.warning( "%s.%s is already registered for tokenizer_mode=%r. " "It is overwritten by the new one.", @@ -72,36 +51,42 @@ class TokenizerRegistry: tokenizer_mode, ) - TokenizerRegistry.REGISTRY[tokenizer_mode] = (module, class_name) + self.tokenizers[tokenizer_mode] = (module, class_name) return None - @staticmethod - def get_tokenizer(tokenizer_mode: str, *args, **kwargs) -> "TokenizerLike": - if tokenizer_mode not in TokenizerRegistry.REGISTRY: + def load_tokenizer_cls(self, tokenizer_mode: str) -> type[TokenizerLike]: + if tokenizer_mode not in self.tokenizers: raise ValueError(f"No tokenizer registered for {tokenizer_mode=!r}.") - item = TokenizerRegistry.REGISTRY[tokenizer_mode] - if isinstance(item, type): - return item.from_pretrained(*args, **kwargs) - - module, class_name = item + module, class_name = self.tokenizers[tokenizer_mode] logger.debug_once(f"Loading {class_name} for {tokenizer_mode=!r}") - class_ = resolve_obj_by_qualname(f"{module}.{class_name}") - return class_.from_pretrained(*args, **kwargs) + return resolve_obj_by_qualname(f"{module}.{class_name}") + + def load_tokenizer(self, tokenizer_mode: str, *args, **kwargs) -> TokenizerLike: + tokenizer_cls = self.load_tokenizer_cls(tokenizer_mode) + return tokenizer_cls.from_pretrained(*args, **kwargs) -def get_tokenizer( +TokenizerRegistry = _TokenizerRegistry( + { + mode: (f"vllm.tokenizers.{mod_relname}", cls_name) + for mode, (mod_relname, cls_name) in _VLLM_TOKENIZERS.items() + } +) + + +def resolve_tokenizer_args( tokenizer_name: str | Path, *args, + runner_type: "RunnerType" = "generate", tokenizer_mode: str = "auto", - trust_remote_code: bool = False, - revision: str | None = None, - download_dir: str | None = None, **kwargs, -) -> TokenizerLike: - """Gets a tokenizer for the given model name via HuggingFace or ModelScope.""" +): + revision: str | None = kwargs.get("revision") + download_dir: str | None = kwargs.get("download_dir") + if envs.VLLM_USE_MODELSCOPE: # download model from ModelScope hub, # lazy import so that modelscope is not required for normal use. @@ -125,16 +110,6 @@ def get_tokenizer( ) tokenizer_name = tokenizer_path - if tokenizer_mode == "slow": - if kwargs.get("use_fast", False): - raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") - - tokenizer_mode = "hf" - kwargs["use_fast"] = False - - if "truncation_side" not in kwargs: - kwargs["truncation_side"] = "left" - # Separate model folder from file path for GGUF models if is_gguf(tokenizer_name): if check_gguf_file(tokenizer_name): @@ -150,6 +125,21 @@ def get_tokenizer( ) kwargs["gguf_file"] = gguf_file + if "truncation_side" not in kwargs: + if runner_type == "generate" or runner_type == "draft": + kwargs["truncation_side"] = "left" + elif runner_type == "pooling": + kwargs["truncation_side"] = "right" + else: + assert_never(runner_type) + + if tokenizer_mode == "slow": + if kwargs.get("use_fast", False): + raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") + + tokenizer_mode = "hf" + kwargs["use_fast"] = False + # Try to use official Mistral tokenizer if possible if tokenizer_mode == "auto" and importlib.util.find_spec("mistral_common"): allow_patterns = ["tekken.json", "tokenizer.model.v*"] @@ -165,49 +155,70 @@ def get_tokenizer( if tokenizer_mode == "auto": tokenizer_mode = "hf" - tokenizer_args = (tokenizer_name, *args) - tokenizer_kwargs = dict( + return tokenizer_mode, tokenizer_name, args, kwargs + + +cached_resolve_tokenizer_args = lru_cache(resolve_tokenizer_args) + + +def tokenizer_args_from_config(config: "ModelConfig", **kwargs): + return cached_resolve_tokenizer_args( + config.tokenizer, + runner_type=config.runner_type, + tokenizer_mode=config.tokenizer_mode, + revision=config.tokenizer_revision, + trust_remote_code=config.trust_remote_code, + **kwargs, + ) + + +_T = TypeVar("_T", bound=TokenizerLike, default=TokenizerLike) + + +def get_tokenizer( + tokenizer_name: str | Path, + *args, + tokenizer_cls: type[_T] = TokenizerLike, # type: ignore[assignment] + trust_remote_code: bool = False, + revision: str | None = None, + download_dir: str | None = None, + **kwargs, +) -> _T: + """Gets a tokenizer for the given model name via HuggingFace or ModelScope.""" + tokenizer_mode, tokenizer_name, args, kwargs = cached_resolve_tokenizer_args( + tokenizer_name, + *args, trust_remote_code=trust_remote_code, revision=revision, download_dir=download_dir, **kwargs, ) - if tokenizer_mode == "custom": - logger.warning_once( - "TokenizerRegistry now uses `tokenizer_mode` as the registry key " - "instead of `tokenizer_name`. " - "Please update the definition of `.from_pretrained` in " - "your custom tokenizer to accept `args=%s`, `kwargs=%s`. " - "Then, you can pass `tokenizer_mode=%r` instead of " - "`tokenizer_mode='custom'` when initializing vLLM.", - tokenizer_args, - str(tokenizer_kwargs), - tokenizer_name, - ) + if tokenizer_cls == TokenizerLike: + tokenizer_cls_ = TokenizerRegistry.load_tokenizer_cls(tokenizer_mode) + else: + tokenizer_cls_ = tokenizer_cls - tokenizer_mode = str(tokenizer_name) - - tokenizer = TokenizerRegistry.get_tokenizer( - tokenizer_mode, - *tokenizer_args, - **tokenizer_kwargs, - ) + tokenizer = tokenizer_cls_.from_pretrained(tokenizer_name, *args, **kwargs) if not tokenizer.is_fast: logger.warning( "Using a slow tokenizer. This might cause a significant " "slowdown. Consider using a fast tokenizer instead." ) - return tokenizer + return tokenizer # type: ignore cached_get_tokenizer = lru_cache(get_tokenizer) def cached_tokenizer_from_config(model_config: "ModelConfig", **kwargs): + if model_config.skip_tokenizer_init: + return None + return cached_get_tokenizer( model_config.tokenizer, + runner_type=model_config.runner_type, tokenizer_mode=model_config.tokenizer_mode, revision=model_config.tokenizer_revision, trust_remote_code=model_config.trust_remote_code, @@ -215,19 +226,8 @@ def cached_tokenizer_from_config(model_config: "ModelConfig", **kwargs): ) +@deprecated( + "Renamed to `cached_tokenizer_from_config`. The old name will be removed in v0.14." +) def init_tokenizer_from_config(model_config: "ModelConfig"): - runner_type = model_config.runner_type - if runner_type == "generate" or runner_type == "draft": - truncation_side = "left" - elif runner_type == "pooling": - truncation_side = "right" - else: - assert_never(runner_type) - - return get_tokenizer( - model_config.tokenizer, - tokenizer_mode=model_config.tokenizer_mode, - trust_remote_code=model_config.trust_remote_code, - revision=model_config.tokenizer_revision, - truncation_side=truncation_side, - ) + return cached_tokenizer_from_config(model_config) diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 8745e1d9dbbbc..90af573535d3b 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -60,17 +60,17 @@ def __getattr__(name: str): return cached_tokenizer_from_config if name == "init_tokenizer_from_configs": - from vllm.tokenizers import init_tokenizer_from_config + from vllm.tokenizers import cached_tokenizer_from_config warnings.warn( "`vllm.transformers_utils.tokenizer.init_tokenizer_from_configs` " - "has been moved to `vllm.tokenizers.init_tokenizer_from_config`. " + "has been moved to `vllm.tokenizers.cached_tokenizer_from_config`. " "The old name will be removed in v0.14.", DeprecationWarning, stacklevel=2, ) - return init_tokenizer_from_config + return cached_tokenizer_from_config raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 8eff61563ccea..a6ee241c41151 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -26,7 +26,7 @@ from vllm.plugins.io_processors import get_io_processor from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.tasks import SupportedTask -from vllm.tokenizers import TokenizerLike, init_tokenizer_from_config +from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config from vllm.tracing import init_tracer from vllm.transformers_utils.config import maybe_register_config_serialize_by_value from vllm.usage.usage_lib import UsageContext @@ -111,7 +111,7 @@ class AsyncLLM(EngineClient): if self.model_config.skip_tokenizer_init: tokenizer = None else: - tokenizer = init_tokenizer_from_config(self.model_config) + tokenizer = cached_tokenizer_from_config(self.model_config) self.input_processor = InputProcessor(self.vllm_config, tokenizer) self.io_processor = get_io_processor( diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 4422eced82fea..1011317b706d3 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -23,7 +23,7 @@ from vllm.plugins.io_processors import get_io_processor from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.tasks import SupportedTask -from vllm.tokenizers import TokenizerLike, init_tokenizer_from_config +from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config from vllm.tracing import init_tracer from vllm.usage.usage_lib import UsageContext from vllm.v1.engine import EngineCoreRequest @@ -86,7 +86,7 @@ class LLMEngine: if self.model_config.skip_tokenizer_init: tokenizer = None else: - tokenizer = init_tokenizer_from_config(self.model_config) + tokenizer = cached_tokenizer_from_config(self.model_config) self.input_processor = InputProcessor(self.vllm_config, tokenizer) self.io_processor = get_io_processor( diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 4dd478804049b..79ee4161e9dfa 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.reasoning import ReasoningParserManager -from vllm.tokenizers import init_tokenizer_from_config +from vllm.tokenizers import cached_tokenizer_from_config from vllm.utils.import_utils import LazyLoader from vllm.v1.structured_output.backend_guidance import GuidanceBackend from vllm.v1.structured_output.backend_types import ( @@ -71,7 +71,7 @@ class StructuredOutputManager: # of CPUs. max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2) self.executor = ThreadPoolExecutor(max_workers=max_workers) - self.tokenizer = init_tokenizer_from_config( + self.tokenizer = cached_tokenizer_from_config( model_config=self.vllm_config.model_config ) reasoning_parser = (