[Misc] Refactor tokenizer interface (#29693)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-11-29 20:02:21 +08:00 committed by GitHub
parent f223ed4181
commit 34a984274e
GPG Key ID: B5690EEEBB952194
119 changed files with 752 additions and 821 deletions
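In short: the tokenizer abstractions move out of `vllm.transformers_utils` into a dedicated `vllm.tokenizers` package, the `AnyTokenizer` alias becomes the `TokenizerLike` protocol, and `tests/tokenization` becomes `tests/tokenizers_`. A minimal before/after sketch of the import migration, assembled from the hunks below (the `stop_token_ids` helper is illustrative only):

    # Before (removed in this commit):
    #   from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
    #   from vllm.transformers_utils.tokenizer_base import TokenizerBase, TokenizerRegistry
    #   from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer

    # After:
    from vllm.tokenizers import MistralTokenizer, TokenizerLike, TokenizerRegistry

    # Loader helpers keep their old home:
    from vllm.transformers_utils.tokenizer import get_cached_tokenizer, get_tokenizer


    def stop_token_ids(tokenizer: TokenizerLike) -> list[int]:
        # TokenizerLike is structural, so HF fast/slow and Mistral
        # tokenizers all satisfy the same annotation.
        return [tokenizer.eos_token_id]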

View File

@ -316,7 +316,7 @@ steps:
source_file_dependencies:
- vllm/
- tests/engine
- tests/tokenization
- tests/tokenizers_
- tests/test_sequence
- tests/test_config
- tests/test_logger
@ -324,7 +324,7 @@ steps:
commands:
- pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py
# OOM in the CI unless we run this separately
- pytest -v -s tokenization
- pytest -v -s tokenizers_
- label: V1 Test e2e + engine # 30min
timeout_in_minutes: 45

View File

@ -282,7 +282,7 @@ steps:
source_file_dependencies:
- vllm/
- tests/engine
- tests/tokenization
- tests/tokenizers_
- tests/test_sequence
- tests/test_config
- tests/test_logger
@ -290,7 +290,7 @@ steps:
commands:
- pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py
# OOM in the CI unless we run this separately
- pytest -v -s tokenization
- pytest -v -s tokenizers_
- label: V1 Test e2e + engine # 30min
timeout_in_minutes: 45

View File

@ -620,7 +620,7 @@ def get_tokenizer(
kwargs["use_fast"] = False
if tokenizer_mode == "mistral":
try:
from vllm.transformers_utils.tokenizer import MistralTokenizer
from vllm.tokenizers import MistralTokenizer
except ImportError as e:
raise ImportError(
"MistralTokenizer requires vllm package.\n"

View File

@ -216,14 +216,13 @@ You can add a new `ReasoningParser` similar to [vllm/reasoning/deepseek_r1_reaso
# import the required packages
from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage)
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
# define a reasoning parser and register it to vllm
# the name list in register_module can be used
# in --reasoning-parser.
class ExampleParser(ReasoningParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
def extract_reasoning_streaming(
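The documented plugin skeleton is cut off above; only its annotation changes in this commit. A minimal registered parser under the new protocol might look like this (parser name and decorator form are illustrative; the extraction hooks are elided rather than guessed):

    from vllm.reasoning import ReasoningParser, ReasoningParserManager
    from vllm.tokenizers import TokenizerLike


    @ReasoningParserManager.register_module(["example"])
    class ExampleParser(ReasoningParser):
        def __init__(self, tokenizer: TokenizerLike):
            # Works for both HF and Mistral tokenizers, since
            # TokenizerLike is a structural protocol.
            super().__init__(tokenizer)

        # extract_reasoning_streaming and the non-streaming hook go
        # here; their signatures are unchanged by this commit.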

View File

@ -422,7 +422,7 @@ Here is a summary of a plugin file:
# in --tool-call-parser. you can define as many
# tool parsers as you want here.
class ExampleToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
# adjust request. e.g.: set skip special tokens
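The same one-line annotation swap applies to tool parser plugins. A sketch of the registration half (parser name illustrative):

    from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
    from vllm.tokenizers import TokenizerLike


    @ToolParserManager.register_module(["example"])
    class ExampleToolParser(ToolParser):
        def __init__(self, tokenizer: TokenizerLike):
            super().__init__(tokenizer)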

View File

@ -10,7 +10,7 @@ import pytest
from vllm.config import ModelConfig
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.tokenizers import MistralTokenizer
@pytest.fixture()

View File

@ -4,9 +4,9 @@
import pytest
from transformers import AutoTokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
@pytest.fixture(scope="function")
def default_tokenizer() -> AnyTokenizer:
def default_tokenizer() -> TokenizerLike:
return AutoTokenizer.from_pretrained("gpt2")

View File

@ -7,7 +7,7 @@ import pytest
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from ....utils import RemoteOpenAIServer
@ -270,14 +270,14 @@ async def test_streaming_product_tool_call():
@pytest.fixture
def qwen_tokenizer() -> AnyTokenizer:
def qwen_tokenizer() -> TokenizerLike:
from vllm.transformers_utils.tokenizer import get_tokenizer
return get_tokenizer("Qwen/Qwen3-32B")
@pytest.fixture
def hermes_parser(qwen_tokenizer: AnyTokenizer) -> Hermes2ProToolParser:
def hermes_parser(qwen_tokenizer: TokenizerLike) -> Hermes2ProToolParser:
return Hermes2ProToolParser(qwen_tokenizer)
@ -291,7 +291,7 @@ def any_chat_request() -> ChatCompletionRequest:
def test_hermes_parser_streaming_just_forward_text(
qwen_tokenizer: AnyTokenizer,
qwen_tokenizer: TokenizerLike,
hermes_parser: Hermes2ProToolParser,
any_chat_request: ChatCompletionRequest,
) -> None:
@ -323,7 +323,7 @@ def test_hermes_parser_streaming_just_forward_text(
def test_hermes_parser_streaming_failure_case_bug_19056(
qwen_tokenizer: AnyTokenizer,
qwen_tokenizer: TokenizerLike,
hermes_parser: Hermes2ProToolParser,
any_chat_request: ChatCompletionRequest,
) -> None:
@ -357,7 +357,7 @@ def test_hermes_parser_streaming_failure_case_bug_19056(
def test_hermes_parser_streaming(
qwen_tokenizer: AnyTokenizer,
qwen_tokenizer: TokenizerLike,
hermes_parser: Hermes2ProToolParser,
any_chat_request: ChatCompletionRequest,
) -> None:

View File

@ -7,11 +7,11 @@ import pytest
from vllm.entrypoints.openai.protocol import ExtractedToolCallInformation
from vllm.entrypoints.openai.tool_parsers.llama_tool_parser import Llama3JsonToolParser
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
@pytest.fixture
def parser(default_tokenizer: AnyTokenizer):
def parser(default_tokenizer: TokenizerLike):
return Llama3JsonToolParser(default_tokenizer)

View File

@ -11,7 +11,7 @@ from tests.entrypoints.openai.tool_parsers.utils import (
)
from vllm.entrypoints.openai.protocol import FunctionCall
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
# Test cases similar to pythonic parser but with Llama4 specific format
SIMPLE_FUNCTION_OUTPUT = "[get_weather(city='LA', metric='C')]"
@ -64,7 +64,7 @@ PYTHON_TAG_FUNCTION_OUTPUT = (
@pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool, default_tokenizer: AnyTokenizer):
def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
default_tokenizer
)
@ -208,7 +208,7 @@ def test_tool_call(
streaming: bool,
model_output: str,
expected_tool_calls: list[FunctionCall],
default_tokenizer: AnyTokenizer,
default_tokenizer: TokenizerLike,
):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
default_tokenizer
@ -224,7 +224,7 @@ def test_tool_call(
assert actual.function == expected
def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
default_tokenizer
)
@ -246,7 +246,7 @@ def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
@pytest.mark.parametrize("streaming", [False])
def test_regex_timeout_handling(streaming: bool, default_tokenizer: AnyTokenizer):
def test_regex_timeout_handling(streaming: bool, default_tokenizer: TokenizerLike):
"""test regex timeout is handled gracefully"""
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
default_tokenizer

View File

@ -11,7 +11,7 @@ from tests.entrypoints.openai.tool_parsers.utils import (
)
from vllm.entrypoints.openai.protocol import FunctionCall
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"
@ -69,7 +69,7 @@ ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
@pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool, default_tokenizer: AnyTokenizer):
def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
default_tokenizer
)
@ -188,7 +188,7 @@ def test_tool_call(
streaming: bool,
model_output: str,
expected_tool_calls: list[FunctionCall],
default_tokenizer: AnyTokenizer,
default_tokenizer: TokenizerLike,
):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
default_tokenizer
@ -205,7 +205,7 @@ def test_tool_call(
assert actual.function == expected
def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
default_tokenizer
)
@ -228,7 +228,7 @@ def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
@pytest.mark.parametrize("streaming", [False])
def test_regex_timeout_handling(streaming: bool, default_tokenizer: AnyTokenizer):
def test_regex_timeout_handling(streaming: bool, default_tokenizer: TokenizerLike):
"""test regex timeout is handled gracefully"""
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
default_tokenizer

View File

@ -11,7 +11,7 @@ from tests.entrypoints.openai.tool_parsers.utils import (
)
from vllm.entrypoints.openai.protocol import FunctionCall
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"
@ -61,7 +61,7 @@ ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
@pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool, default_tokenizer: AnyTokenizer):
def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
default_tokenizer
)
@ -168,7 +168,7 @@ def test_tool_call(
streaming: bool,
model_output: str,
expected_tool_calls: list[FunctionCall],
default_tokenizer: AnyTokenizer,
default_tokenizer: TokenizerLike,
):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
default_tokenizer
@ -185,7 +185,7 @@ def test_tool_call(
assert actual.function == expected
def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
default_tokenizer
)
@ -208,7 +208,7 @@ def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
@pytest.mark.parametrize("streaming", [False])
def test_regex_timeout_handling(streaming: bool, default_tokenizer: AnyTokenizer):
def test_regex_timeout_handling(streaming: bool, default_tokenizer: TokenizerLike):
"""test regex timeout is handled gracefully"""
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
default_tokenizer

View File

@ -11,7 +11,7 @@ from vllm.entrypoints.openai.protocol import (
ToolCall,
)
from vllm.entrypoints.openai.tool_parsers import ToolParser
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
class StreamingToolReconstructor:
@ -111,7 +111,7 @@ def run_tool_extraction_nonstreaming(
return tool_parser.extract_tool_calls(model_output, request)
def split_string_into_token_deltas(tokenizer: AnyTokenizer, text: str) -> list[str]:
def split_string_into_token_deltas(tokenizer: TokenizerLike, text: str) -> list[str]:
# Split a string into a series of deltas using the provided tokenizer. Each
# delta will be the string equivalent of a single token.
token_ids = tokenizer.encode(text, add_special_tokens=False)
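The hunk cuts off after the encode call; a plausible completion of the delta splitter under the new annotation (the per-token decode loop is an assumption based on the comment, not the verbatim upstream body):

    from vllm.tokenizers import TokenizerLike


    def split_string_into_token_deltas(tokenizer: TokenizerLike, text: str) -> list[str]:
        # Split a string into a series of deltas using the provided tokenizer.
        # Each delta is the string equivalent of a single token.
        token_ids = tokenizer.encode(text, add_special_tokens=False)
        return [
            tokenizer.decode([token_id], skip_special_tokens=False)
            for token_id in token_ids
        ]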

View File

@ -28,8 +28,8 @@ from vllm.multimodal.utils import (
encode_image_base64,
encode_video_base64,
)
from vllm.tokenizers import MistralTokenizer
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import VLLM_PATH

View File

@ -10,7 +10,7 @@ from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
MistralToolParser,
)
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import MistralTokenizer
from vllm.tokenizers import MistralTokenizer
from ...utils import check_logprobs_close

View File

@ -9,7 +9,7 @@ from mistral_common.audio import Audio
from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage
from vllm.transformers_utils.tokenizer import MistralTokenizer
from vllm.tokenizers import MistralTokenizer
from ....conftest import AudioTestAssets
from ....utils import RemoteOpenAIServer

View File

@ -9,7 +9,7 @@ import torch
from transformers.models.auto.auto_factory import _BaseAutoModelClass
from vllm.config.model import RunnerOption
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from .....conftest import HfRunner, VllmRunner
from ....registry import HF_EXAMPLE_MODELS
@ -33,7 +33,7 @@ def run_test(
auto_cls: type[_BaseAutoModelClass],
use_tokenizer_eos: bool,
comparator: Callable[..., None],
get_stop_token_ids: Callable[[AnyTokenizer], list[int]] | None,
get_stop_token_ids: Callable[[TokenizerLike], list[int]] | None,
stop_str: list[str] | None,
limit_mm_per_prompt: dict[str, int],
vllm_runner_kwargs: dict[str, Any] | None,

View File

@ -14,7 +14,7 @@ from transformers.models.auto.auto_factory import _BaseAutoModelClass
from vllm.config.model import RunnerOption
from vllm.logprobs import SampleLogprobs
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from .....conftest import (
AUDIO_ASSETS,
@ -126,7 +126,7 @@ class VLMTestInfo(NamedTuple):
vllm_runner_kwargs: dict[str, Any] | None = None
# Optional callable which gets a list of token IDs from the model tokenizer
get_stop_token_ids: Callable[[AnyTokenizer], list[int]] | None = None
get_stop_token_ids: Callable[[TokenizerLike], list[int]] | None = None
# Optional list of strings to stop generation, useful when stop tokens are
# not special tokens in the tokenizer
stop_str: list[str] | None = None

View File

@ -22,8 +22,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
from vllm.multimodal.cache import MultiModalProcessorOnlyCache
from vllm.multimodal.inputs import MultiModalInputs
from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext
from vllm.tokenizers import MistralTokenizer
from vllm.transformers_utils.tokenizer import (
MistralTokenizer,
cached_tokenizer_from_config,
encode_tokens,
)

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from contextlib import nullcontext
from typing import cast
@ -23,7 +24,7 @@ from vllm.multimodal.processing import (
replace_token_matches,
)
from vllm.multimodal.profiling import MultiModalProfiler
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from .utils import random_image
@ -238,7 +239,7 @@ def test_find_token_matches(
update_type,
):
# Should not be used since there is nothing to convert to token IDs
mock_tokenizer = cast(AnyTokenizer, object())
mock_tokenizer = cast(TokenizerLike, object())
prompt_updates = {
key: update_type(key, target, []).resolve(0)
@ -385,7 +386,7 @@ def test_find_text_matches(
update_type,
):
# Should not be used since there is nothing to convert to text
mock_tokenizer = cast(AnyTokenizer, object())
mock_tokenizer = cast(TokenizerLike, object())
prompt_updates = {
key: update_type(key, target, []).resolve(0)
@ -545,7 +546,7 @@ def test_find_update_text(
expected_by_update_type_mm_count,
):
# Should not be used since there is nothing to convert to text
mock_tokenizer = cast(AnyTokenizer, object())
mock_tokenizer = cast(TokenizerLike, object())
for (
update_type,
@ -750,7 +751,7 @@ def test_find_update_tokens(
expected_by_update_type_mm_count,
):
# Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(AnyTokenizer, object())
mock_tokenizer = cast(TokenizerLike, object())
for (
update_type,
@ -900,7 +901,7 @@ def test_find_mm_placeholders(
update_type,
):
# Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(AnyTokenizer, object())
mock_tokenizer = cast(TokenizerLike, object())
mm_prompt_updates = {
key: [[update_type(key, [], repl).resolve(i)] for i in range(3)]
@ -1029,7 +1030,7 @@ def test_hf_processor_init_kwargs(
expected_kwargs,
):
# Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(AnyTokenizer, object())
mock_tokenizer = cast(TokenizerLike, object())
ctx = InputProcessingContext(
model_config=ModelConfig(model_id, mm_processor_kwargs=config_kwargs),
@ -1065,7 +1066,7 @@ def test_hf_processor_call_kwargs(
expected_kwargs,
):
# Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(AnyTokenizer, object())
mock_tokenizer = cast(TokenizerLike, object())
ctx = InputProcessingContext(
model_config=ModelConfig(model_id, mm_processor_kwargs=config_kwargs),
@ -1088,9 +1089,7 @@ def test_apply_matches_no_match_exits_quickly():
With the fix, it should exit immediately when no match is found.
"""
import time
mock_tokenizer = cast(AnyTokenizer, object())
mock_tokenizer = cast(TokenizerLike, object())
# Create a long prompt with no placeholder
long_prompt = "x" * 10000

View File

@ -5,7 +5,7 @@ import pytest
from tests.reasoning.utils import run_reasoning_extraction_mistral
from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.tokenizers import MistralTokenizer
parser_name = "mistral"

View File

@ -4,7 +4,7 @@
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
from vllm.reasoning import ReasoningParser
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.tokenizers import MistralTokenizer
class StreamingReasoningReconstructor:

View File

@ -1,18 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.transformers_utils.tokenizer import get_tokenizer
TOKENIZER_NAMES = ["BAAI/bge-base-en"]
@pytest.mark.parametrize("tokenizer_name", TOKENIZER_NAMES)
@pytest.mark.parametrize("n_tokens", [510])
def test_special_tokens(tokenizer_name: str, n_tokens: int):
tokenizer = get_tokenizer(tokenizer_name, revision="main")
prompts = "[UNK]" * n_tokens
prompt_token_ids = tokenizer.encode(prompts)
assert len(prompt_token_ids) == n_tokens + 2

View File

@ -1,32 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This test file includes some cases where it is inappropriate to
only get the `eos_token_id` from the tokenizer as defined by
{meth}`vllm.LLMEngine._get_eos_token_id`.
"""
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.tokenizer import get_tokenizer
def test_get_llama3_eos_token():
model_name = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = get_tokenizer(model_name)
assert tokenizer.eos_token_id == 128009
generation_config = try_get_generation_config(model_name, trust_remote_code=False)
assert generation_config is not None
assert generation_config.eos_token_id == [128001, 128008, 128009]
def test_get_blip2_eos_token():
model_name = "Salesforce/blip2-opt-2.7b"
tokenizer = get_tokenizer(model_name)
assert tokenizer.eos_token_id == 2
generation_config = try_get_generation_config(model_name, trust_remote_code=False)
assert generation_config is not None
assert generation_config.eos_token_id == 50118

View File

@ -1,23 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from transformers import PreTrainedTokenizerBase
from vllm.transformers_utils.tokenizer import get_tokenizer
TOKENIZER_NAMES = [
"facebook/opt-125m",
"gpt2",
]
@pytest.mark.parametrize("tokenizer_name", TOKENIZER_NAMES)
def test_tokenizer_revision(tokenizer_name: str):
# Assume that "main" branch always exists
tokenizer = get_tokenizer(tokenizer_name, revision="main")
assert isinstance(tokenizer, PreTrainedTokenizerBase)
# Assume that "never" branch always does not exist
with pytest.raises(OSError, match="not a valid git identifier"):
get_tokenizer(tokenizer_name, revision="never")

View File

@ -1,120 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Any
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.transformers_utils.tokenizer_base import TokenizerBase, TokenizerRegistry
if TYPE_CHECKING:
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
class TestTokenizer(TokenizerBase):
@classmethod
def from_pretrained(cls, *args, **kwargs) -> "TestTokenizer":
return TestTokenizer()
@property
def all_special_tokens(self) -> list[str]:
raise NotImplementedError()
@property
def all_special_ids(self) -> list[int]:
raise NotImplementedError()
@property
def bos_token_id(self) -> int:
return 0
@property
def eos_token_id(self) -> int:
return 1
@property
def sep_token(self) -> str:
raise NotImplementedError()
@property
def pad_token(self) -> str:
raise NotImplementedError()
@property
def is_fast(self) -> bool:
raise NotImplementedError()
@property
def vocab_size(self) -> int:
raise NotImplementedError()
@property
def max_token_id(self) -> int:
raise NotImplementedError()
@property
def truncation_side(self) -> str:
raise NotImplementedError()
def __call__(
self,
text: str | list[str] | list[int],
text_pair: str | None = None,
add_special_tokens: bool = False,
truncation: bool = False,
max_length: int | None = None,
):
raise NotImplementedError()
def get_vocab(self) -> dict[str, int]:
raise NotImplementedError()
def get_added_vocab(self) -> dict[str, int]:
raise NotImplementedError()
def encode_one(
self,
text: str,
truncation: bool = False,
max_length: int | None = None,
) -> list[int]:
raise NotImplementedError()
def encode(self, text: str, add_special_tokens: bool | None = None) -> list[int]:
raise NotImplementedError()
def apply_chat_template(
self,
messages: list["ChatCompletionMessageParam"],
tools: list[dict[str, Any]] | None = None,
**kwargs,
) -> list[int]:
raise NotImplementedError()
def convert_tokens_to_string(self, tokens: list[str]) -> str:
raise NotImplementedError()
def decode(self, ids: list[int] | int, skip_special_tokens: bool = True) -> str:
raise NotImplementedError()
def convert_ids_to_tokens(
self,
ids: list[int],
skip_special_tokens: bool = True,
) -> list[str]:
raise NotImplementedError()
def test_customized_tokenizer():
TokenizerRegistry.register(
"test_tokenizer", "tests.tokenization.test_tokenizer_registry", "TestTokenizer"
)
tokenizer = TokenizerRegistry.get_tokenizer("test_tokenizer")
assert isinstance(tokenizer, TestTokenizer)
assert tokenizer.bos_token_id == 0
assert tokenizer.eos_token_id == 1
tokenizer = get_tokenizer("test_tokenizer", tokenizer_mode="custom")
assert isinstance(tokenizer, TestTokenizer)
assert tokenizer.bos_token_id == 0
assert tokenizer.eos_token_id == 1

View File

@ -0,0 +1,4 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# NOTE: Since CI runs the tests from the `tests` directory, it is necessary to rename
# this module to avoid conflicting with HF's `tokenizers` package

View File

@ -0,0 +1,59 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import _get_protocol_attrs # type: ignore
import pytest
from transformers import PreTrainedTokenizerBase
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.tokenizer import get_tokenizer
def _get_missing_attrs(obj: object, target: type):
return [k for k in _get_protocol_attrs(target) if not hasattr(obj, k)]
def test_tokenizer_like_protocol():
assert not (
missing_attrs := _get_missing_attrs(
get_tokenizer("gpt2", use_fast=False),
TokenizerLike,
)
), f"Missing attrs: {missing_attrs}"
assert not (
missing_attrs := _get_missing_attrs(
get_tokenizer("gpt2", use_fast=True),
TokenizerLike,
)
), f"Missing attrs: {missing_attrs}"
assert not (
missing_attrs := _get_missing_attrs(
get_tokenizer(
"mistralai/Mistral-7B-Instruct-v0.3", tokenizer_mode="mistral"
),
TokenizerLike,
)
), f"Missing attrs: {missing_attrs}"
@pytest.mark.parametrize("tokenizer_name", ["facebook/opt-125m", "gpt2"])
def test_tokenizer_revision(tokenizer_name: str):
# Assume that "main" branch always exists
tokenizer = get_tokenizer(tokenizer_name, revision="main")
assert isinstance(tokenizer, PreTrainedTokenizerBase)
# Assume that "never" branch always does not exist
with pytest.raises(OSError, match="not a valid git identifier"):
get_tokenizer(tokenizer_name, revision="never")
@pytest.mark.parametrize("tokenizer_name", ["BAAI/bge-base-en"])
@pytest.mark.parametrize("n_tokens", [510])
def test_special_tokens(tokenizer_name: str, n_tokens: int):
tokenizer = get_tokenizer(tokenizer_name, revision="main")
prompts = "[UNK]" * n_tokens
prompt_token_ids = tokenizer.encode(prompts)
assert len(prompt_token_ids) == n_tokens + 2
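`_get_protocol_attrs` is a private `typing` helper that enumerates a protocol's members; the test asserts each concrete tokenizer exposes all of them, with no inheritance required. The same structural check in isolation, with a toy protocol (names illustrative):

    from typing import Protocol, _get_protocol_attrs  # type: ignore


    class Greeter(Protocol):
        name: str

        def greet(self) -> str: ...


    class ConcreteGreeter:
        name = "gpt2"

        def greet(self) -> str:
            return f"hello from {self.name}"


    # Mirrors _get_missing_attrs above: every protocol member must
    # exist on the object, whether it inherits from Greeter or not.
    missing = [k for k in _get_protocol_attrs(Greeter) if not hasattr(ConcreteGreeter(), k)]
    assert not missing, f"Missing attrs: {missing}"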

View File

@ -6,7 +6,8 @@ from copy import deepcopy
import pytest
from transformers import AutoTokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_cached_tokenizer
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.tokenizer import get_cached_tokenizer
@pytest.mark.parametrize("model_id", ["gpt2", "zai-org/chatglm3-6b"])
@ -25,7 +26,7 @@ def test_cached_tokenizer(model_id: str):
_check_consistency(unpickled_tokenizer, reference_tokenizer)
def _check_consistency(target: AnyTokenizer, expected: AnyTokenizer):
def _check_consistency(target: TokenizerLike, expected: TokenizerLike):
assert isinstance(target, type(expected))
# Cached attributes

View File

@ -8,7 +8,7 @@ import pytest
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.tokenizers import MistralTokenizer
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.detokenizer import (
FastIncrementalDetokenizer,

View File

@ -7,7 +7,7 @@ import pytest
from mistral_common.exceptions import InvalidMessageStructureException
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
from vllm.transformers_utils.tokenizers.mistral import (
from vllm.tokenizers.mistral import (
MistralTokenizer,
_prepare_apply_chat_template_tools_and_messages,
)
@ -308,25 +308,6 @@ class TestMistralTokenizer:
def test_get_added_vocab(self, mistral_tokenizer: MistralTokenizer):
assert mistral_tokenizer.get_added_vocab() == {}
def test_encode_one(self, mistral_tokenizer: MistralTokenizer):
token_ids = (
[22177, 4304, 2662] if mistral_tokenizer.is_tekken else [23325, 2294, 1686]
)
assert mistral_tokenizer.encode_one("Hello world !") == token_ids
assert mistral_tokenizer.encode_one("Hello world !", max_length=1) == token_ids
assert (
mistral_tokenizer.encode_one("Hello world !", truncation=True, max_length=1)
== token_ids[:-2]
)
assert (
mistral_tokenizer.encode_one(
"Hello world !", truncation=False, max_length=1
)
== token_ids
)
assert mistral_tokenizer.encode_one("") == []
def test_encode(self, mistral_tokenizer: MistralTokenizer):
token_ids = (
[1, 22177, 4304, 2662]

View File

@ -0,0 +1,36 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.tokenizers import TokenizerLike, TokenizerRegistry
from vllm.transformers_utils.tokenizer import get_tokenizer
class TestTokenizer(TokenizerLike):
@classmethod
def from_pretrained(cls, *args, **kwargs) -> "TestTokenizer":
return TestTokenizer() # type: ignore
@property
def bos_token_id(self) -> int:
return 0
@property
def eos_token_id(self) -> int:
return 1
def test_customized_tokenizer():
TokenizerRegistry.register(
"test_tokenizer",
__name__,
TestTokenizer.__name__,
)
tokenizer = TokenizerRegistry.get_tokenizer("test_tokenizer")
assert isinstance(tokenizer, TestTokenizer)
assert tokenizer.bos_token_id == 0
assert tokenizer.eos_token_id == 1
tokenizer = get_tokenizer("test_tokenizer", tokenizer_mode="custom")
assert isinstance(tokenizer, TestTokenizer)
assert tokenizer.bos_token_id == 0
assert tokenizer.eos_token_id == 1

View File

@ -14,8 +14,9 @@ from vllm.entrypoints.openai.protocol import (
ToolCall,
)
from vllm.entrypoints.openai.tool_parsers.ernie45_tool_parser import Ernie45ToolParser
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
from vllm.transformers_utils.tokenizer import get_tokenizer
# Use a common model that is likely to be available
MODEL = "baidu/ERNIE-4.5-21B-A3B-Thinking"
@ -173,7 +174,7 @@ def test_extract_tool_calls(
def stream_delta_message_generator(
ernie45_tool_parser: Ernie45ToolParser,
ernie45_tokenizer: AnyTokenizer,
ernie45_tokenizer: TokenizerLike,
model_output: str,
request: ChatCompletionRequest | None = None,
) -> Generator[DeltaMessage, None, None]:

View File

@ -10,8 +10,9 @@ from partial_json_parser.core.options import Allow
from vllm.entrypoints.openai.protocol import DeltaMessage, FunctionCall, ToolCall
from vllm.entrypoints.openai.tool_parsers.jamba_tool_parser import JambaToolParser
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
from vllm.transformers_utils.tokenizer import get_tokenizer
pytestmark = pytest.mark.cpu_test
@ -44,7 +45,9 @@ def assert_tool_calls(
def stream_delta_message_generator(
jamba_tool_parser: JambaToolParser, jamba_tokenizer: AnyTokenizer, model_output: str
jamba_tool_parser: JambaToolParser,
jamba_tokenizer: TokenizerLike,
model_output: str,
) -> Generator[DeltaMessage, None, None]:
all_token_ids = jamba_tokenizer.encode(model_output, add_special_tokens=False)
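These generators encode the full model output up front, then replay it token by token through `detokenize_incrementally` to synthesize streaming deltas. A naive stand-in that shows the idea without the helper's offset bookkeeping (assumes any `TokenizerLike`; slower, but equivalent for well-behaved tokenizers):

    from collections.abc import Generator

    from vllm.tokenizers import TokenizerLike


    def naive_stream_deltas(
        tokenizer: TokenizerLike, model_output: str
    ) -> Generator[str, None, None]:
        # Decode growing prefixes and emit the difference as the delta,
        # so multi-token characters are never split across deltas.
        all_token_ids = tokenizer.encode(model_output, add_special_tokens=False)
        previous_text = ""
        for i in range(1, len(all_token_ids) + 1):
            current_text = tokenizer.decode(all_token_ids[:i], skip_special_tokens=False)
            yield current_text[len(previous_text):]
            previous_text = current_text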

View File

@ -17,8 +17,9 @@ from vllm.entrypoints.openai.tool_parsers.qwen3coder_tool_parser import (
Qwen3CoderToolParser,
)
from vllm.entrypoints.openai.tool_parsers.qwen3xml_tool_parser import Qwen3XMLToolParser
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
from vllm.transformers_utils.tokenizer import get_tokenizer
pytestmark = pytest.mark.cpu_test
@ -104,7 +105,7 @@ def assert_tool_calls(
def stream_delta_message_generator(
qwen3_tool_parser,
qwen3_tokenizer: AnyTokenizer,
qwen3_tokenizer: TokenizerLike,
model_output: str,
request: ChatCompletionRequest | None = None,
) -> Generator[DeltaMessage, None, None]:

View File

@ -15,8 +15,9 @@ from vllm.entrypoints.openai.protocol import (
ToolCall,
)
from vllm.entrypoints.openai.tool_parsers.seed_oss_tool_parser import SeedOssToolParser
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
from vllm.transformers_utils.tokenizer import get_tokenizer
pytestmark = pytest.mark.cpu_test
@ -256,7 +257,7 @@ def test_streaming_tool_calls_no_tools(seed_oss_tool_parser):
def stream_delta_message_generator(
seed_oss_tool_parser: SeedOssToolParser,
seed_oss_tokenizer: AnyTokenizer,
seed_oss_tokenizer: TokenizerLike,
model_output: str,
request: ChatCompletionRequest | None = None,
) -> Generator[DeltaMessage, None, None]:

View File

@ -13,8 +13,9 @@ from vllm.entrypoints.openai.protocol import (
ToolCall,
)
from vllm.entrypoints.openai.tool_parsers.xlam_tool_parser import xLAMToolParser
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
from vllm.transformers_utils.tokenizer import get_tokenizer
pytestmark = pytest.mark.cpu_test
@ -49,7 +50,7 @@ def assert_tool_calls(
def stream_delta_message_generator(
xlam_tool_parser: xLAMToolParser,
xlam_tokenizer: AnyTokenizer,
xlam_tokenizer: TokenizerLike,
model_output: str,
request: ChatCompletionRequest | None = None,
) -> Generator[DeltaMessage, None, None]:

View File

@ -1,62 +1,32 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This test file includes some cases where it is inappropriate to
only get the `eos_token_id` from the tokenizer as defined by
`vllm.LLMEngine._get_eos_token_id`.
"""
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.tokenizer import get_tokenizer
def test_get_llama3_eos_token():
model_name = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = get_tokenizer(model_name)
assert tokenizer.eos_token_id == 128009
generation_config = try_get_generation_config(model_name, trust_remote_code=False)
assert generation_config is not None
assert generation_config.eos_token_id == [128001, 128008, 128009]
def test_get_blip2_eos_token():
model_name = "Salesforce/blip2-opt-2.7b"
tokenizer = get_tokenizer(model_name)
assert tokenizer.eos_token_id == 2
generation_config = try_get_generation_config(model_name, trust_remote_code=False)
assert generation_config is not None
assert generation_config.eos_token_id == 50118

View File

@ -0,0 +1,62 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, call, patch
import pytest
from vllm.transformers_utils.repo_utils import list_filtered_repo_files
@pytest.mark.parametrize(
"allow_patterns,expected_relative_files",
[
(
["*.json", "correct*.txt"],
["json_file.json", "subfolder/correct.txt", "correct_2.txt"],
),
],
)
def test_list_filtered_repo_files(
allow_patterns: list[str], expected_relative_files: list[str]
):
with tempfile.TemporaryDirectory() as tmp_dir:
# Prep folder and files
path_tmp_dir = Path(tmp_dir)
subfolder = path_tmp_dir / "subfolder"
subfolder.mkdir()
(path_tmp_dir / "json_file.json").touch()
(path_tmp_dir / "correct_2.txt").touch()
(path_tmp_dir / "uncorrect.txt").touch()
(path_tmp_dir / "uncorrect.jpeg").touch()
(subfolder / "correct.txt").touch()
(subfolder / "uncorrect_sub.txt").touch()
def _glob_path() -> list[str]:
return [
str(file.relative_to(path_tmp_dir))
for file in path_tmp_dir.glob("**/*")
if file.is_file()
]
# Patch list_repo_files called by fn
with patch(
"vllm.transformers_utils.repo_utils.list_repo_files",
MagicMock(return_value=_glob_path()),
) as mock_list_repo_files:
out_files = sorted(
list_filtered_repo_files(
tmp_dir, allow_patterns, "revision", "model", "token"
)
)
assert out_files == sorted(expected_relative_files)
assert mock_list_repo_files.call_count == 1
assert mock_list_repo_files.call_args_list[0] == call(
repo_id=tmp_dir,
revision="revision",
repo_type="model",
token="token",
)
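The fixtures pin down the matching semantics: `correct*.txt` must pick up `subfolder/correct.txt` but not `uncorrect.txt`. A behavioral sketch consistent with those expectations, not the actual upstream implementation (`PurePath.match` anchors single-component patterns at the file name):

    from pathlib import PurePath


    def filter_by_allow_patterns(files: list[str], allow_patterns: list[str]) -> list[str]:
        # Keep a file if any allow pattern matches it. PurePath.match
        # compares from the right, so "correct*.txt" also matches
        # "subfolder/correct.txt" while rejecting "uncorrect.txt".
        return [
            file
            for file in files
            if any(PurePath(file).match(pattern) for pattern in allow_patterns)
        ]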

View File

@ -18,7 +18,7 @@ from vllm.logprobs import PromptLogprobs, SampleLogprobs
from vllm.lora.request import LoRARequest
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.v1.engine import (
EngineCoreEvent,
EngineCoreEventType,
@ -31,7 +31,7 @@ from vllm.v1.metrics.stats import IterationStats, SchedulerStats
def _ref_convert_id_to_token(
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
token_id: int,
) -> str:
"""Reference impl of logprobs detokenization.

View File

@ -27,8 +27,8 @@ ALLOWED_FILES = {
"vllm/distributed/device_communicators/shm_broadcast.py",
"vllm/distributed/device_communicators/shm_object_storage.py",
"vllm/utils/hashing.py",
"tests/tokenizers_/test_cached_tokenizer.py",
"tests/utils_/test_hashing.py",
"tests/tokenization/test_cached_tokenizer.py",
"benchmarks/kernels/graph_machete_bench.py",
"benchmarks/kernels/benchmark_lora.py",
"benchmarks/kernels/benchmark_machete.py",

View File

@ -35,6 +35,7 @@ FILES = [
"vllm/multimodal",
"vllm/platforms",
"vllm/plugins",
"vllm/tokenizers",
"vllm/transformers_utils",
"vllm/triton_utils",
"vllm/usage",

View File

@ -39,7 +39,7 @@ from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.image import convert_image_mode
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.utils.import_utils import PlaceholderModule
try:
@ -293,7 +293,7 @@ def lora_path_on_disk(lora_path: str) -> str:
# Global cache for LoRA tokenizers.
lora_tokenizer_cache: dict[int, AnyTokenizer] = {}
lora_tokenizer_cache: dict[int, TokenizerLike] = {}
def process_image(image: Any) -> Mapping[str, Any]:

View File

@ -13,7 +13,7 @@ from vllm.plugins.io_processors import IOProcessor
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.input_processor import InputProcessor
@ -85,7 +85,7 @@ class EngineClient(ABC):
...
@abstractmethod
async def get_tokenizer(self) -> AnyTokenizer:
async def get_tokenizer(self) -> TokenizerLike:
"""Get the tokenizer"""
...

View File

@ -49,9 +49,9 @@ from vllm.logger import init_logger
from vllm.model_executor.models import SupportsMultiModal
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict
from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector
from vllm.tokenizers import MistralTokenizer, TokenizerLike
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import random_uuid
from vllm.utils.func_utils import supports_kw
@ -536,7 +536,7 @@ def resolve_hf_chat_template(
def _resolve_chat_template_content_format(
chat_template: str | None,
tools: list[dict[str, Any]] | None,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
*,
model_config: ModelConfig,
) -> _ChatTemplateContentFormat:
@ -593,7 +593,7 @@ def resolve_chat_template_content_format(
chat_template: str | None,
tools: list[dict[str, Any]] | None,
given_format: ChatTemplateContentFormatOption,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
*,
model_config: ModelConfig,
) -> _ChatTemplateContentFormat:
@ -627,7 +627,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
maximum per prompt.
"""
def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):
def __init__(self, model_config: ModelConfig, tokenizer: TokenizerLike):
super().__init__()
self._model_config = model_config
@ -1592,7 +1592,7 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None:
def parse_chat_messages(
messages: list[ChatCompletionMessageParam],
model_config: ModelConfig,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
content_format: _ChatTemplateContentFormat,
) -> tuple[
list[ConversationMessage],
@ -1624,7 +1624,7 @@ def parse_chat_messages(
def parse_chat_messages_futures(
messages: list[ChatCompletionMessageParam],
model_config: ModelConfig,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
content_format: _ChatTemplateContentFormat,
) -> tuple[
list[ConversationMessage],

View File

@ -71,11 +71,8 @@ from vllm.platforms import current_platform
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
from vllm.tasks import PoolingTask
from vllm.transformers_utils.tokenizer import (
AnyTokenizer,
MistralTokenizer,
get_cached_tokenizer,
)
from vllm.tokenizers import MistralTokenizer, TokenizerLike
from vllm.transformers_utils.tokenizer import get_cached_tokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils.collection_utils import as_iter, is_list_of
from vllm.utils.counter import Counter
@ -350,11 +347,11 @@ class LLM:
self.input_processor = self.llm_engine.input_processor
self.io_processor = self.llm_engine.io_processor
def get_tokenizer(self) -> AnyTokenizer:
def get_tokenizer(self) -> TokenizerLike:
return self.llm_engine.get_tokenizer()
@deprecated("`set_tokenizer` is deprecated and will be removed in v0.13.")
def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
def set_tokenizer(self, tokenizer: TokenizerLike) -> None:
# While CachedTokenizer is dynamic, have no choice but
# compare class name. Misjudgment will arise from
# user-defined tokenizer started with 'Cached'
@ -1244,7 +1241,7 @@ class LLM:
def _embedding_score(
self,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
text_1: list[str | TextPrompt | TokensPrompt],
text_2: list[str | TextPrompt | TokensPrompt],
truncate_prompt_tokens: int | None = None,
@ -1276,7 +1273,7 @@ class LLM:
def _cross_encoding_score(
self,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
data_1: list[str] | list[ScoreContentPartParam],
data_2: list[str] | list[ScoreContentPartParam],
truncate_prompt_tokens: int | None = None,

View File

@ -62,8 +62,9 @@ from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.transformers_utils.tokenizers import (
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import (
MistralTokenizer,
maybe_serialize_tool_calls,
truncate_tool_call_ids,
validate_request_params,
@ -530,7 +531,7 @@ class OpenAIServingChat(OpenAIServing):
request_id: str,
model_name: str,
conversation: list[ConversationMessage],
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata,
) -> AsyncGenerator[str, None]:
created_time = int(time.time())
@ -1296,7 +1297,7 @@ class OpenAIServingChat(OpenAIServing):
request_id: str,
model_name: str,
conversation: list[ConversationMessage],
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata,
) -> ErrorResponse | ChatCompletionResponse:
created_time = int(time.time())
@ -1624,7 +1625,7 @@ class OpenAIServingChat(OpenAIServing):
self,
logprobs: dict[int, Logprob],
top_logprobs: int | None,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
should_return_as_token_id: bool,
) -> list[ChatCompletionLogProb]:
return [
@ -1648,7 +1649,7 @@ class OpenAIServingChat(OpenAIServing):
self,
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[dict[int, Logprob] | None],
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
num_output_top_logprobs: int | None = None,
return_as_token_id: bool | None = None,
) -> ChatCompletionLogProbs:

View File

@ -221,7 +221,7 @@ class ServingClassification(ClassificationMixin):
def _create_pooling_params(
self,
ctx: ClassificationServeContext,
ctx: ServeContext[ClassificationRequest],
) -> PoolingParams | ErrorResponse:
pooling_params = super()._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse):

View File

@ -33,7 +33,7 @@ from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import as_list
from vllm.v1.sample.logits_processor import validate_logits_processors_parameters
@ -326,7 +326,7 @@ class OpenAIServingCompletion(OpenAIServing):
created_time: int,
model_name: str,
num_prompts: int,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike | None,
request_metadata: RequestResponseMetadata,
) -> AsyncGenerator[str, None]:
num_choices = 1 if request.n is None else request.n
@ -511,7 +511,7 @@ class OpenAIServingCompletion(OpenAIServing):
request_id: str,
created_time: int,
model_name: str,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike | None,
request_metadata: RequestResponseMetadata,
) -> CompletionResponse:
choices: list[CompletionResponseChoice] = []
@ -622,7 +622,7 @@ class OpenAIServingCompletion(OpenAIServing):
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[dict[int, Logprob] | None],
num_output_top_logprobs: int,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike | None,
initial_text_offset: int = 0,
return_as_token_id: bool | None = None,
) -> CompletionLogProbs:
@ -642,9 +642,15 @@ class OpenAIServingCompletion(OpenAIServing):
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None:
token = tokenizer.decode(token_id)
if should_return_as_token_id:
token = f"token_id:{token_id}"
else:
if tokenizer is None:
raise ValueError(
"Unable to get tokenizer because `skip_tokenizer_init=True`"
)
token = tokenizer.decode(token_id)
out_tokens.append(token)
out_token_logprobs.append(None)

View File

@ -7,13 +7,14 @@ import time
import traceback
from collections.abc import AsyncGenerator, Callable, Iterable, Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Any, ClassVar, Generic, TypeAlias, TypeVar
import numpy as np
import torch
from fastapi import Request
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from pydantic import ConfigDict, TypeAdapter
from starlette.datastructures import Headers
from typing_extensions import TypeIs
@ -96,12 +97,12 @@ from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import MistralTokenizer, TokenizerLike
from vllm.tracing import (
contains_trace_headers,
extract_trace_headers,
log_tracing_disabled_warning,
)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import random_uuid
from vllm.utils.async_utils import (
AsyncMicrobatchTokenizer,
@ -184,19 +185,19 @@ def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
RequestT = TypeVar("RequestT", bound=AnyRequest)
class RequestProcessingMixin(BaseModel):
@dataclass(kw_only=True)
class RequestProcessingMixin:
"""
Mixin for request processing,
handling prompt preparation and engine input.
"""
request_prompts: Sequence[RequestPrompt] | None = []
engine_prompts: list[EngineTokensPrompt] | None = []
model_config = ConfigDict(arbitrary_types_allowed=True)
request_prompts: Sequence[RequestPrompt] | None = field(default_factory=list)
engine_prompts: list[EngineTokensPrompt] | None = field(default_factory=list)
class ResponseGenerationMixin(BaseModel):
@dataclass(kw_only=True)
class ResponseGenerationMixin:
"""
Mixin for response generation,
managing result generators and final batch results.
@ -205,54 +206,38 @@ class ResponseGenerationMixin(BaseModel):
result_generator: (
AsyncGenerator[tuple[int, RequestOutput | PoolingRequestOutput], None] | None
) = None
final_res_batch: list[RequestOutput | PoolingRequestOutput] = Field(
final_res_batch: list[RequestOutput | PoolingRequestOutput] = field(
default_factory=list
)
model_config = ConfigDict(arbitrary_types_allowed=True)
class ServeContext(
RequestProcessingMixin,
ResponseGenerationMixin,
BaseModel,
Generic[RequestT],
):
@dataclass(kw_only=True)
class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, Generic[RequestT]):
# Shared across all requests
request: RequestT
raw_request: Request | None = None
model_name: str
request_id: str
created_time: int = Field(default_factory=lambda: int(time.time()))
created_time: int = field(default_factory=lambda: int(time.time()))
lora_request: LoRARequest | None = None
# Shared across most requests
tokenizer: AnyTokenizer | None = None
# `protected_namespaces` resolves Pydantic v2's warning
# on conflict with protected namespace "model_"
model_config = ConfigDict(
protected_namespaces=(),
arbitrary_types_allowed=True,
)
tokenizer: TokenizerLike | None = None
ClassificationServeContext = ServeContext[ClassificationRequest]
@dataclass(kw_only=True)
class ClassificationServeContext(ServeContext[ClassificationRequest]):
pass
@dataclass(kw_only=True)
class EmbeddingServeContext(ServeContext[EmbeddingRequest]):
chat_template: str | None = None
chat_template_content_format: ChatTemplateContentFormatOption
# Used to resolve the Pydantic error related to
# forward reference of MultiModalDataDict in TokensPrompt
RequestProcessingMixin.model_rebuild()
ServeContext.model_rebuild()
ClassificationServeContext.model_rebuild()
EmbeddingServeContext.model_rebuild()
class OpenAIServing:
request_id_prefix: ClassVar[str] = """
A short string prepended to every request's ID (e.g. "embd", "classify")
@ -281,7 +266,7 @@ class OpenAIServing:
apply_mistral_chat_template, executor=self._tokenizer_executor
)
self._async_tokenizer_pool: dict[AnyTokenizer, AsyncMicrobatchTokenizer] = {}
self._async_tokenizer_pool: dict[TokenizerLike, AsyncMicrobatchTokenizer] = {}
self.log_error_stack = log_error_stack
self.input_processor = self.models.input_processor
@ -291,7 +276,7 @@ class OpenAIServing:
def _get_tool_parser(
self, tool_parser_name: str | None = None, enable_auto_tools: bool = False
) -> Callable[[AnyTokenizer], ToolParser] | None:
) -> Callable[[TokenizerLike], ToolParser] | None:
"""Get the tool parser based on the name."""
parser = None
if not enable_auto_tools or tool_parser_name is None:
@ -317,7 +302,7 @@ class OpenAIServing:
def _get_reasoning_parser(
self,
reasoning_parser_name: str,
) -> Callable[[AnyTokenizer], ReasoningParser] | None:
) -> Callable[[TokenizerLike], ReasoningParser] | None:
"""Get the reasoning parser based on the name."""
parser = None
if not reasoning_parser_name:
@ -547,7 +532,7 @@ class OpenAIServing:
prompt_logprobs=None,
)
def _get_renderer(self, tokenizer: AnyTokenizer | None) -> BaseRenderer:
def _get_renderer(self, tokenizer: TokenizerLike | None) -> BaseRenderer:
"""
Get a Renderer instance with the provided tokenizer.
Uses shared async tokenizer pool for efficiency.
@ -877,7 +862,7 @@ class OpenAIServing:
self,
request: AnyRequest,
prompt: str,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
add_special_tokens: bool,
) -> TextTokensPrompt:
async_tokenizer = self._get_async_tokenizer(tokenizer)
@ -919,7 +904,7 @@ class OpenAIServing:
self,
request: AnyRequest,
prompt_ids: list[int],
tokenizer: AnyTokenizer | None,
tokenizer: TokenizerLike | None,
) -> TextTokensPrompt:
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
@ -1015,7 +1000,7 @@ class OpenAIServing:
async def _tokenize_prompt_input_async(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
prompt_input: str | list[int],
add_special_tokens: bool = True,
) -> TextTokensPrompt:
@ -1034,7 +1019,7 @@ class OpenAIServing:
async def _tokenize_prompt_inputs_async(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
prompt_inputs: Iterable[str | list[int]],
add_special_tokens: bool = True,
) -> AsyncGenerator[TextTokensPrompt, None]:
@ -1079,7 +1064,7 @@ class OpenAIServing:
async def _preprocess_chat(
self,
request: ChatLikeRequest | ResponsesRequest,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike | None,
messages: list[ChatCompletionMessageParam],
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
@ -1088,13 +1073,18 @@ class OpenAIServing:
tool_dicts: list[dict[str, Any]] | None = None,
documents: list[dict[str, str]] | None = None,
chat_template_kwargs: dict[str, Any] | None = None,
tool_parser: Callable[[AnyTokenizer], ToolParser] | None = None,
tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
add_special_tokens: bool = False,
) -> tuple[
list[ConversationMessage],
Sequence[RequestPrompt],
list[EngineTokensPrompt],
]:
if tokenizer is None:
raise ValueError(
"Unable to get tokenizer because `skip_tokenizer_init=True`"
)
model_config = self.model_config
resolved_content_format = resolve_chat_template_content_format(
@ -1370,9 +1360,9 @@ class OpenAIServing:
@staticmethod
def _parse_tool_calls_from_content(
request: ResponsesRequest | ChatCompletionRequest,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
enable_auto_tools: bool,
tool_parser_cls: Callable[[AnyTokenizer], ToolParser] | None,
tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
content: str | None = None,
) -> tuple[list[FunctionCall] | None, str | None]:
function_calls = list[FunctionCall]()
@ -1442,7 +1432,7 @@ class OpenAIServing:
def _get_decoded_token(
logprob: Logprob,
token_id: int,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike | None,
return_as_token_id: bool = False,
) -> str:
if return_as_token_id:
@ -1450,6 +1440,12 @@ class OpenAIServing:
if logprob.decoded_token is not None:
return logprob.decoded_token
if tokenizer is None:
raise ValueError(
"Unable to get tokenizer because `skip_tokenizer_init=True`"
)
return tokenizer.decode(token_id)
def _is_model_supported(self, model_name: str | None) -> bool:
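The mixin conversion in this file trades Pydantic models for stdlib dataclasses: these contexts are mutable containers, so validation buys nothing, and the change drops `arbitrary_types_allowed`, the `protected_namespaces` workaround for the `model_name` field, and the `model_rebuild()` calls for forward references. The shape of the pattern in isolation (field set abridged from the diff; the class name is illustrative):

    import time
    from dataclasses import dataclass, field
    from typing import Generic, TypeVar

    RequestT = TypeVar("RequestT")


    @dataclass(kw_only=True)
    class ServeContextSketch(Generic[RequestT]):
        request: RequestT
        model_name: str  # no protected-namespace warning outside Pydantic
        request_id: str
        created_time: int = field(default_factory=lambda: int(time.time()))
        final_res_batch: list[object] = field(default_factory=list)


    ctx = ServeContextSketch(request="ping", model_name="gpt2", request_id="req-0")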

View File

@ -105,7 +105,7 @@ from vllm.logprobs import Logprob as SampleLogprob
from vllm.logprobs import SampleLogprobs
from vllm.outputs import CompletionOutput
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.utils import random_uuid
logger = init_logger(__name__)
@ -492,7 +492,7 @@ class OpenAIServingResponses(OpenAIServing):
self,
request: ResponsesRequest,
prev_response: ResponsesResponse | None,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
):
if request.tools is None or (
request.tool_choice == "none" and self.exclude_tools_when_tool_choice_none
@ -563,7 +563,7 @@ class OpenAIServingResponses(OpenAIServing):
result_generator: AsyncIterator[ConversationContext],
context: ConversationContext,
model_name: str,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata,
created_time: int | None = None,
) -> ErrorResponse | ResponsesResponse:
@ -675,7 +675,7 @@ class OpenAIServingResponses(OpenAIServing):
self,
logprobs: dict[int, SampleLogprob],
top_logprobs: int,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
) -> list[LogprobTopLogprob]:
"""Returns the top-k logprobs from the logprobs dictionary."""
out = []
@ -700,7 +700,7 @@ class OpenAIServingResponses(OpenAIServing):
self,
token_ids: Sequence[int],
logprobs: SampleLogprobs | None,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
top_logprobs: int | None = None,
) -> list[Logprob]:
assert logprobs is not None, "logprobs must be provided"
@ -736,7 +736,7 @@ class OpenAIServingResponses(OpenAIServing):
self,
token_ids: Sequence[int],
logprobs: SampleLogprobs | None,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
top_logprobs: int | None = None,
) -> list[response_text_delta_event.Logprob]:
lgs = self._create_response_logprobs(
@ -763,7 +763,7 @@ class OpenAIServingResponses(OpenAIServing):
self,
request: ResponsesRequest,
final_output: CompletionOutput,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
) -> list[ResponseOutputItem]:
if self.reasoning_parser:
try:
@ -1135,7 +1135,7 @@ class OpenAIServingResponses(OpenAIServing):
result_generator: AsyncIterator[ConversationContext | None],
context: ConversationContext,
model_name: str,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata,
created_time: int,
_increment_sequence_number_and_return: Callable[
@ -1438,7 +1438,7 @@ class OpenAIServingResponses(OpenAIServing):
result_generator: AsyncIterator[ConversationContext | None],
context: ConversationContext,
model_name: str,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata,
created_time: int,
_increment_sequence_number_and_return: Callable[
@ -1891,7 +1891,7 @@ class OpenAIServingResponses(OpenAIServing):
result_generator: AsyncIterator[ConversationContext | None],
context: ConversationContext,
model_name: str,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata,
created_time: int | None = None,
) -> AsyncGenerator[StreamingResponsesResponse, None]:
View File
@ -36,7 +36,7 @@ from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.tokenizers import MistralTokenizer, TokenizerLike
from vllm.utils.async_utils import make_async, merge_async_iterators
logger = init_logger(__name__)
@ -60,7 +60,7 @@ class ServingScores(OpenAIServing):
async def _embedding_score(
self,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
texts_1: list[str],
texts_2: list[str],
request: RerankRequest | ScoreRequest,
@ -153,7 +153,7 @@ class ServingScores(OpenAIServing):
def _preprocess_score(
self,
request: RerankRequest | ScoreRequest,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
tokenization_kwargs: dict[str, Any],
data_1: str | ScoreContentPartParam,
data_2: str | ScoreContentPartParam,
@ -175,7 +175,7 @@ class ServingScores(OpenAIServing):
async def _cross_encoding_score(
self,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
data_1: list[str] | list[ScoreContentPartParam],
data_2: list[str] | list[ScoreContentPartParam],
request: RerankRequest | ScoreRequest,
View File
@ -22,7 +22,7 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
@ -170,7 +170,7 @@ class OpenAIServingTokenization(OpenAIServing):
@dataclass
class TokenizerInfo:
tokenizer: AnyTokenizer
tokenizer: TokenizerLike
chat_template: str | None
def to_dict(self) -> dict[str, Any]:
View File
@ -22,7 +22,7 @@ from vllm.logger import init_logger
from vllm.sampling_params import (
StructuredOutputsParams,
)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.utils.collection_utils import is_list_of
from vllm.utils.import_utils import import_from_path
@ -36,7 +36,7 @@ class ToolParser:
derived classes.
"""
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
self.prev_tool_call_arr: list[dict] = []
# the index of the tool call that is currently being parsed
self.current_tool_id: int = -1
View File
@ -19,13 +19,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
class DeepSeekV31ToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.current_tool_name_sent: bool = False
View File
@ -19,13 +19,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
class DeepSeekV3ToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.current_tool_name_sent: bool = False
View File
@ -19,13 +19,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
class Ernie45ToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
"""
Ernie thinking model format:
abc\n</think>\n\n\n<tool_call>\ndef\n</tool_call>\n
View File
@ -22,13 +22,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
class Glm4MoeModelToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.current_tool_name_sent = False
self.prev_tool_call_arr: list[dict] = []
View File
@ -29,7 +29,7 @@ from vllm.entrypoints.openai.tool_parsers.utils import (
partial_json_loads,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
@ -44,7 +44,7 @@ class Granite20bFCToolParser(ToolParser):
are all set
"""
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.bot_token = "<function_call>"
View File
@ -27,7 +27,7 @@ from vllm.entrypoints.openai.tool_parsers.utils import (
partial_json_loads,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
@ -42,7 +42,7 @@ class GraniteToolParser(ToolParser):
are all set
"""
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
# for granite 3.0, the token `<|tool_call|>`
self.bot_token = "<|tool_call|>"
View File
@ -22,18 +22,18 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.tokenizers import MistralTokenizer, TokenizerLike
logger = init_logger(__name__)
class Hermes2ProToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
if isinstance(self.model_tokenizer, MistralTokenizer):
if isinstance(tokenizer, MistralTokenizer):
logger.error("Detected Mistral tokenizer when using a Hermes model")
self.model_tokenizer = self.model_tokenizer.tokenizer
self.model_tokenizer = tokenizer.tokenizer
self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = []
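Besides the type rename, the Hermes hunk switches both the `isinstance` check and the unwrapping to the `tokenizer` constructor argument instead of `self.model_tokenizer`. A reduced sketch of the unwrapping idea, with a stub standing in for the real `MistralTokenizer`:

class MistralTokenizer:
    # Stub for vllm.tokenizers.MistralTokenizer, which wraps an underlying
    # mistral-common tokenizer in its `tokenizer` attribute.
    def __init__(self, inner):
        self.tokenizer = inner

def unwrap_for_hermes(tokenizer):
    # Hermes parsing expects an HF-style tokenizer, so a Mistral wrapper
    # is replaced by the tokenizer it wraps.
    if isinstance(tokenizer, MistralTokenizer):
        return tokenizer.tokenizer
    return tokenizer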
View File
@ -22,14 +22,14 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
)
from vllm.entrypoints.openai.tool_parsers.utils import consume_space
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.utils import random_uuid
logger = init_logger(__name__)
class HunyuanA13BToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
# Initialize state for streaming mode
View File
@ -22,13 +22,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
)
from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
class Internlm2ToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.position = 0
View File
@ -21,14 +21,13 @@ from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.tool_parsers import ToolParser
from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizers import MistralTokenizer
from vllm.tokenizers import MistralTokenizer, TokenizerLike
logger = init_logger(__name__)
class JambaToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
if isinstance(self.model_tokenizer, MistralTokenizer):
View File
@ -19,13 +19,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
class KimiK2ToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = []
View File
@ -4,11 +4,11 @@
import regex as re
from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
class LongcatFlashToolParser(Hermes2ProToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.tool_call_start_token: str = "<longcat_tool_call>"
View File
@ -21,13 +21,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
class MinimaxM2ToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.prev_tool_call_arr: list[dict] = []
View File
@ -22,13 +22,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
)
from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
class MinimaxToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
# Initialize streaming state for tracking tool call progress
View File
@ -25,7 +25,7 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
)
from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.tokenizers import MistralTokenizer, TokenizerLike
logger = init_logger(__name__)
@ -46,7 +46,7 @@ class MistralToolCall(ToolCall):
return id.isalnum() and len(id) == 9
def _is_fn_name_regex_support(model_tokenizer: AnyTokenizer) -> bool:
def _is_fn_name_regex_support(model_tokenizer: TokenizerLike) -> bool:
return (
isinstance(model_tokenizer, MistralTokenizer) and model_tokenizer.version >= 11
)
@ -61,7 +61,7 @@ class MistralToolParser(ToolParser):
Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
"""
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
if not isinstance(self.model_tokenizer, MistralTokenizer):
View File
@ -18,15 +18,15 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
from vllm.logger import init_logger
if TYPE_CHECKING:
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
else:
AnyTokenizer = object
TokenizerLike = object
logger = init_logger(__name__)
class OpenAIToolParser(ToolParser):
def __init__(self, tokenizer: "AnyTokenizer"):
def __init__(self, tokenizer: "TokenizerLike"):
super().__init__(tokenizer)
def extract_tool_calls(
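The `TYPE_CHECKING` arrangement above keeps `TokenizerLike` usable in annotations without importing vLLM internals at runtime. The general shape of the pattern, as a self-contained sketch:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported for static type checkers only; never executed at runtime.
    from vllm.tokenizers import TokenizerLike
else:
    # Runtime placeholder so the name still resolves if it is ever touched.
    TokenizerLike = object

def handle(tokenizer: "TokenizerLike") -> None:
    # The quoted annotation defers evaluation, so only type checkers
    # ever look the name up as a type.
    ...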
View File
@ -22,13 +22,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
class Qwen3CoderToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.current_tool_name_sent: bool = False
View File
@ -23,7 +23,7 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
@ -1165,7 +1165,7 @@ class StreamingXMLToolCallParser:
class Qwen3XMLToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.parser = StreamingXMLToolCallParser()
View File
@ -25,7 +25,7 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
@ -34,7 +34,7 @@ class SeedOssToolParser(ToolParser):
TOOL_CALL_START = "<seed:tool_call>"
TOOL_CALL_END = "</seed:tool_call>"
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
# --- streaming state ---
View File
@ -21,7 +21,7 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.utils import random_uuid
logger = init_logger(__name__)
@ -41,7 +41,7 @@ class Step3ToolParser(ToolParser):
TOOL_SEP = "<tool_sep>"
SPECIAL_TOKENS = [TOOL_CALLS_BEGIN, TOOL_CALLS_END, TOOL_CALL_BEGIN, TOOL_CALL_END]
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.position = 0
# Explicit state flags for robust streaming
View File
@ -21,14 +21,14 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.utils import random_uuid
logger = init_logger(__name__)
class xLAMToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
# Initialize state for streaming mode
View File
@ -16,7 +16,7 @@ from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
from vllm.inputs.data import TextPrompt as EngineTextPrompt
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.inputs.parse import get_prompt_components, parse_raw_prompts
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import AsyncMicrobatchTokenizer
@ -85,7 +85,7 @@ class BaseRenderer(ABC):
def __init__(
self,
model_config: ModelConfig,
tokenizer: AnyTokenizer | None = None,
tokenizer: TokenizerLike | None = None,
):
super().__init__()
self.model_config = model_config
@ -200,8 +200,8 @@ class CompletionRenderer(BaseRenderer):
def __init__(
self,
model_config: ModelConfig,
tokenizer: AnyTokenizer | None = None,
async_tokenizer_pool: dict[AnyTokenizer, AsyncMicrobatchTokenizer]
tokenizer: TokenizerLike | None = None,
async_tokenizer_pool: dict[TokenizerLike, AsyncMicrobatchTokenizer]
| None = None,
):
super().__init__(model_config, tokenizer)
@ -373,7 +373,7 @@ class CompletionRenderer(BaseRenderer):
return async_tokenizer
tokenizer = self.tokenizer
if self.tokenizer is None:
if tokenizer is None:
raise ValueError("No tokenizer available for text input processing")
if self.async_tokenizer_pool is None:
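Beyond the rename, this renderer hunk corrects the guard to test the local `tokenizer` it just read instead of re-reading `self.tokenizer`. Reduced to a sketch (the class is an illustrative stand-in, not vLLM's CompletionRenderer):

class _Renderer:
    def __init__(self, tokenizer=None):
        self.tokenizer = tokenizer

    def _tokenizer_or_raise(self):
        tokenizer = self.tokenizer
        # Validate the captured local, so the value returned is exactly
        # the one that passed the check.
        if tokenizer is None:
            raise ValueError("No tokenizer available for text input processing")
        return tokenizer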
View File
@ -19,11 +19,7 @@ from vllm.inputs import TokensPrompt
from vllm.model_executor.models.interfaces import supports_score_template
from vllm.multimodal.inputs import MultiModalDataDict
from vllm.outputs import PoolingRequestOutput
from vllm.transformers_utils.tokenizer import (
AnyTokenizer,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
)
from vllm.tokenizers import TokenizerLike
ScoreContentPartParam: TypeAlias = (
ChatCompletionContentPartImageParam | ChatCompletionContentPartImageEmbedsParam
@ -45,7 +41,7 @@ class ScoreMultiModalParam(TypedDict, total=False):
def _cosine_similarity(
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
tokenizer: TokenizerLike,
embed_1: list[PoolingRequestOutput],
embed_2: list[PoolingRequestOutput],
) -> list[PoolingRequestOutput]:
@ -93,7 +89,7 @@ def parse_score_data(
data_1: str | ScoreContentPartParam,
data_2: str | ScoreContentPartParam,
model_config: ModelConfig,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
) -> tuple[str, str, MultiModalDataDict | None]:
mm_tracker = MultiModalItemTracker(model_config, tokenizer)
@ -118,12 +114,14 @@ def _parse_score_content(
mm_tracker: BaseMultiModalItemTracker,
) -> _ContentPart | None:
if isinstance(data, str):
data = ChatCompletionContentPartTextParam(type="text", text=data)
part = ChatCompletionContentPartTextParam(type="text", text=data)
else:
part = data
mm_parser = mm_tracker.create_parser()
parse_res = _parse_chat_message_content_part(
data,
part,
mm_parser,
wrap_dicts=False,
interleave_strings=False,
@ -181,7 +179,7 @@ def post_process_tokens(
def get_score_prompt(
model_config: ModelConfig,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
tokenization_kwargs: dict[str, Any],
data_1: str | ScoreContentPartParam,
data_2: str | ScoreContentPartParam,
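The `_parse_score_content` hunk stops rebinding `data` to a value of a different type and introduces `part` instead, so the argument keeps its declared type and the narrowing stays explicit. The step in isolation (plain dicts stand in for the typed-dict constructor):

def normalize_score_content(data):
    # Bare strings become text parts; structured parts pass through as-is.
    if isinstance(data, str):
        part = {"type": "text", "text": data}
    else:
        part = data
    return part

assert normalize_score_content("hi") == {"type": "text", "text": "hi"}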
View File
@ -30,7 +30,7 @@ from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.serving_models import LoRAModulePath
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.transformers_utils.tokenizers import MistralTokenizer
from vllm.tokenizers import MistralTokenizer
from vllm.utils.argparse_utils import FlexibleArgumentParser
logger = init_logger(__name__)
View File
@ -17,7 +17,7 @@ from vllm.multimodal.inputs import (
MultiModalUUIDDict,
)
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.utils.jsontree import json_iter_leaves
from vllm.v1.metrics.stats import MultiModalCacheStats
@ -46,7 +46,7 @@ class InputPreprocessor:
def __init__(
self,
model_config: ModelConfig,
tokenizer: AnyTokenizer | None,
tokenizer: TokenizerLike | None,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
mm_processor_cache: BaseMultiModalProcessorCache | None = None,
) -> None:
@ -59,7 +59,7 @@ class InputPreprocessor:
self.mm_cache_stats = MultiModalCacheStats() if mm_processor_cache else None
def get_tokenizer(self) -> AnyTokenizer:
def get_tokenizer(self) -> TokenizerLike:
if self.tokenizer is None:
raise ValueError(
"You cannot pass text prompts when `skip_tokenizer_init` is True"
@ -228,11 +228,11 @@ class InputPreprocessor:
return tokenizer.encode(prompt, **tokenization_kwargs)
def _get_mm_tokenizer(self) -> AnyTokenizer:
def _get_mm_tokenizer(self) -> TokenizerLike:
# PrithviGeoSpatialMAE needs to be initialized without a tokenizer
# while using also multi-modal input
if not self.tokenizer:
return cast(AnyTokenizer, object()) # Dummy
return cast(TokenizerLike, object()) # Dummy
tokenizer = self.get_tokenizer()
return tokenizer
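`_get_mm_tokenizer` keeps its pre-existing escape hatch: a model like PrithviGeoSpatialMAE takes multi-modal input but never tokenizes text, so an inert object typed as `TokenizerLike` is enough. A standalone sketch of that cast (the class here is a stub, not the real protocol):

from typing import cast

class TokenizerLike:
    # Stub standing in for vllm.tokenizers.TokenizerLike.
    pass

def get_mm_tokenizer(tokenizer):
    if tokenizer is None:
        # Callers in this mode never invoke tokenizer methods,
        # so a dummy object satisfies the annotation.
        return cast(TokenizerLike, object())
    return tokenizer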
View File
@ -5,7 +5,7 @@ from typing import TypeAlias
import torch
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
LogitsProcessor: TypeAlias = (
Callable[[list[int], torch.Tensor], torch.Tensor]
@ -19,7 +19,7 @@ to sample from."""
def get_bad_words_logits_processors(
bad_words: list[str], tokenizer: AnyTokenizer
bad_words: list[str], tokenizer: TokenizerLike
) -> list[LogitsProcessor]:
bad_words_ids: list[list[int]] = list()
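`get_bad_words_logits_processors` builds callables matching the `LogitsProcessor` alias above, `Callable[[list[int], torch.Tensor], torch.Tensor]`. A simplified sketch for single-token bad words (vLLM's real version also handles multi-token sequences by inspecting the generated prefix):

import torch

def make_bad_words_processor(banned_token_ids: list[int]):
    def processor(past_token_ids: list[int], logits: torch.Tensor) -> torch.Tensor:
        # Push banned tokens to -inf so sampling can never pick them.
        logits[banned_token_ids] = float("-inf")
        return logits
    return processor

# Usage: logits = make_bad_words_processor([42])(generated_ids, logits)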
View File
@ -28,7 +28,7 @@ from vllm.multimodal.processing import (
PromptUpdate,
PromptUpdateDetails,
)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from .intern_vit import InternVisionModel
from .internvl import (
@ -241,7 +241,7 @@ class H2OVLProcessor(BaseInternVLProcessor):
def __init__(
self,
config: PretrainedConfig,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
*,
min_dynamic_patch: int | None = None,
max_dynamic_patch: int | None = None,
View File
@ -50,7 +50,7 @@ from vllm.multimodal.processing import (
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from vllm.utils.torch_utils import set_default_torch_num_threads
@ -347,7 +347,7 @@ class BaseInternVLProcessor(ABC):
def __init__(
self,
config: PretrainedConfig,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
*,
min_dynamic_patch: int | None = None,
max_dynamic_patch: int | None = None,
@ -561,7 +561,7 @@ class InternVLProcessor(BaseInternVLProcessor):
def __init__(
self,
config: PretrainedConfig,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
*,
min_dynamic_patch: int | None = None,
max_dynamic_patch: int | None = None,
View File
@ -73,9 +73,9 @@ from vllm.multimodal.processing import (
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.configs.radio import RadioConfig
from vllm.transformers_utils.tokenizer import (
AnyTokenizer,
cached_tokenizer_from_config,
encode_tokens,
)
@ -284,7 +284,7 @@ class BaseNanoNemotronVLProcessor(ABC):
def __init__(
self,
config: PretrainedConfig,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
*args,
max_num_tiles: int | None = None,
**kwargs,
@ -434,7 +434,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
def __init__(
self,
config: PretrainedConfig,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
*,
max_num_tiles: int | None = None,
min_dynamic_patch: int | None = None,
@ -645,7 +645,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
tokens_per_frame: list[int],
frames_indices: list[int],
frame_duration_ms: int,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
img_start_token_ids: list[int],
img_end_token_ids: list[int],
img_context_token_ids: list[int],
@ -670,7 +670,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
tokens_per_frame (list[int]): number of tokens per frame
frames_indices (list[int]): frame indices
frame_duration_ms (int): duration of each frame in milliseconds
tokenizer (AnyTokenizer): tokenizer to use for tokenizing frame separators
tokenizer (TokenizerLike): tokenizer to use for tokenizing frame separators
img_start_token_ids (list[int]): pre-tokenized IMG_START tokens
img_end_token_ids (list[int]): pre-tokenized IMG_END tokens
img_context_token_ids (list[int]): pre-tokenized IMG_CONTEXT tokens
View File
@ -34,8 +34,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import convert_image_mode
from vllm.multimodal.processing import PromptUpdateDetails
from vllm.sequence import IntermediateTensors
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.processor import cached_image_processor_from_config
from vllm.transformers_utils.tokenizer import AnyTokenizer
from .interfaces import (
MultiModalEmbeddings,
@ -203,7 +203,7 @@ class NemotronVLProcessor(InternVLProcessor):
def __init__(
self,
config: PretrainedConfig,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
image_processor: BaseImageProcessorFast,
*,
min_dynamic_patch: int | None = None,
View File
@ -31,7 +31,7 @@ from vllm.multimodal.processing import (
PromptReplacement,
PromptUpdate,
)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from .qwen2_5_vl import (
Qwen2_5_VisionTransformer as OpenCUAVisionTransformer,
@ -79,7 +79,7 @@ class OpenCUAProcessor(Qwen2VLProcessor):
def __init__(
self,
vision_config: dict,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
**kwargs,
):
image_processor = Qwen2VLImageProcessor(**vision_config)
View File
@ -59,10 +59,8 @@ from vllm.multimodal.processing import (
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import (
MistralTokenizer,
cached_tokenizer_from_config,
)
from vllm.tokenizers import MistralTokenizer
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
View File
@ -91,7 +91,7 @@ from vllm.multimodal.processing import (
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (
@ -1533,7 +1533,7 @@ class Tarsier2Processor(Qwen2VLProcessor):
def __init__(
self,
vision_config: dict,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
**kwargs,
):
self.image_processor = Tarsier2ImageProcessor(**vision_config)
View File
@ -47,7 +47,7 @@ from vllm.multimodal.processing import (
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
@ -282,7 +282,7 @@ class SkyworkR1VProcessor:
def __init__(
self,
config: PretrainedConfig,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
*,
min_dynamic_patch: int | None = None,
max_dynamic_patch: int | None = None,
View File
@ -43,8 +43,8 @@ from vllm.multimodal.processing import (
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.configs import Step3VisionEncoderConfig
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
@ -321,7 +321,7 @@ class Step3VLProcessor:
def __init__(
self,
config: PretrainedConfig,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
) -> None:
super().__init__()
View File
@ -51,10 +51,8 @@ from vllm.multimodal.processing import (
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import (
MistralTokenizer,
cached_tokenizer_from_config,
)
from vllm.tokenizers import MistralTokenizer
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription
from .utils import init_vllm_registered_model, maybe_prefix
View File
@ -23,8 +23,9 @@ import torch
from typing_extensions import TypeVar, assert_never
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.transformers_utils.tokenizer import AnyTokenizer, decode_tokens, encode_tokens
from vllm.transformers_utils.tokenizer import decode_tokens, encode_tokens
from vllm.utils.collection_utils import flatten_2d_lists, full_groupby
from vllm.utils.func_utils import get_allowed_kwarg_only_overrides
from vllm.utils.jsontree import JSONTree, json_map_leaves
@ -76,7 +77,7 @@ PromptSeq: TypeAlias = str | list[int]
@lru_cache(maxsize=2048)
def _cached_encode(
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
text: str,
*,
add_special_tokens: bool | None = None,
@ -86,7 +87,7 @@ def _cached_encode(
@lru_cache(maxsize=2048)
def _cached_decode(
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
token_ids: tuple[int, ...],
*,
skip_special_tokens: bool | None = None,
@ -96,14 +97,14 @@ def _cached_decode(
)
def _seq2text(tokenizer: AnyTokenizer, seq: PromptSeq) -> str:
def _seq2text(tokenizer: TokenizerLike, seq: PromptSeq) -> str:
if isinstance(seq, str):
return seq
return _cached_decode(tokenizer, tuple(seq))
def _seq2tokens(tokenizer: AnyTokenizer, seq: PromptSeq) -> list[int]:
def _seq2tokens(tokenizer: TokenizerLike, seq: PromptSeq) -> list[int]:
if isinstance(seq, str):
return _cached_encode(tokenizer, seq, add_special_tokens=False)
@ -113,7 +114,7 @@ def _seq2tokens(tokenizer: AnyTokenizer, seq: PromptSeq) -> list[int]:
class _GetMatchIndex(Protocol):
def __call__(
self,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
prompt: PromptSeq,
start_idx: int = 0,
) -> int | None: ...
@ -143,7 +144,7 @@ class PromptIndexTargets:
"""
def get_match_index(
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
prompt: PromptSeq,
start_idx: int = 0,
) -> int | None:
@ -199,7 +200,7 @@ class PromptUpdateDetails(Generic[_S]):
full: _S
"""The full content."""
is_embed: Callable[[AnyTokenizer, PromptSeq], torch.Tensor] | None = None
is_embed: Callable[[TokenizerLike, PromptSeq], torch.Tensor] | None = None
"""
Given [`full`][vllm.multimodal.processing.PromptUpdateDetails.full],
return a boolean mask of shape `(len(full),)` indicating which positions
@ -220,7 +221,7 @@ class PromptUpdateDetails(Generic[_S]):
seq: _S,
embed_text: str,
) -> "PromptUpdateDetails[_S]":
def is_embed(tokenizer: AnyTokenizer, full: PromptSeq) -> torch.Tensor:
def is_embed(tokenizer: TokenizerLike, full: PromptSeq) -> torch.Tensor:
embed_token_ids = encode_tokens(tokenizer, embed_text)
token_ids = _seq2tokens(tokenizer, full)
@ -236,7 +237,7 @@ class PromptUpdateDetails(Generic[_S]):
seq: _S,
embed_token_id: int,
) -> "PromptUpdateDetails[_S]":
def is_embed(tokenizer: AnyTokenizer, full: PromptSeq) -> torch.Tensor:
def is_embed(tokenizer: TokenizerLike, full: PromptSeq) -> torch.Tensor:
token_ids = _seq2tokens(tokenizer, full)
return torch.tensor(token_ids) == embed_token_id
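Both `is_embed` callbacks above reduce to an elementwise comparison over the full token sequence, producing one boolean per token position. In isolation:

import torch

def embed_mask(token_ids: list[int], embed_token_id: int) -> torch.Tensor:
    # True at every position holding the embedding placeholder token.
    return torch.tensor(token_ids) == embed_token_id

print(embed_mask([5, 7, 7, 9], 7))  # tensor([False,  True,  True, False])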
@ -522,7 +523,7 @@ class ResolvedPromptUpdate:
def iter_token_matches(
self,
prompt: list[int],
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
*,
start_idx: int = 0,
) -> Generator[PromptTargetMatch]:
@ -544,7 +545,7 @@ class ResolvedPromptUpdate:
def iter_text_matches(
self,
prompt: str,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
*,
start_idx: int = 0,
) -> Generator[PromptTargetMatch]:
@ -566,7 +567,7 @@ class ResolvedPromptUpdate:
def iter_matches(
self,
prompt: list[int] | str,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
*,
start_idx: int = 0,
) -> Generator[PromptTargetMatch]:
@ -675,7 +676,7 @@ _MatchToApply = tuple[tuple[str, int], tuple[PromptTargetMatch, int]]
def _find_matches(
prompt: _S,
mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
*,
prev_end_idx: int = 0,
current_result: "MultiModalPromptUpdatesApplyResult",
@ -740,7 +741,7 @@ def _all_items_found(
def _apply_matches(
prompt: _S,
mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
) -> tuple[list[_S], "MultiModalPromptUpdatesApplyResult"]:
mm_item_counts = {m: len(items) for m, items in mm_prompt_updates.items()}
@ -806,7 +807,7 @@ def _apply_matches(
def apply_token_matches(
prompt: list[int],
mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
) -> tuple[list[int], "MultiModalPromptUpdatesApplyResult"]:
"""
Apply the updates in `mm_prompt_updates` to `prompt`.
@ -823,7 +824,7 @@ def apply_token_matches(
def apply_text_matches(
prompt: str,
mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
) -> tuple[str, "MultiModalPromptUpdatesApplyResult"]:
"""
Apply the updates in `mm_prompt_updates` to `prompt`.
@ -840,7 +841,7 @@ def apply_text_matches(
def _iter_placeholders(
prompt: list[int],
mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
) -> Iterable[PlaceholderFeaturesInfo]:
"""
Yield each set of placeholder tokens found in `prompt`.
@ -909,7 +910,7 @@ def _iter_placeholders(
def find_mm_placeholders(
prompt: list[int],
mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
it = _iter_placeholders(prompt, mm_prompt_updates, tokenizer)
return dict(full_groupby_modality(it))
@ -930,7 +931,7 @@ class InputProcessingContext:
model_config: ModelConfig
"""The configuration of the model."""
tokenizer: AnyTokenizer
tokenizer: TokenizerLike
"""The tokenizer used to tokenize the inputs."""
@overload
@ -1146,7 +1147,7 @@ class BaseProcessingInfo:
def model_id(self) -> str:
return self.ctx.model_config.model
def get_tokenizer(self) -> AnyTokenizer:
def get_tokenizer(self) -> TokenizerLike:
return self.ctx.tokenizer
def get_hf_config(self) -> PretrainedConfig:
View File
@ -6,7 +6,8 @@ from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, cached_tokenizer_from_config
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from .cache import BaseMultiModalProcessorCache
from .processing import (
@ -231,17 +232,20 @@ class MultiModalRegistry:
def _create_processing_ctx(
self,
model_config: "ModelConfig",
tokenizer: AnyTokenizer | None = None,
tokenizer: TokenizerLike | None = None,
) -> InputProcessingContext:
if tokenizer is None and not model_config.skip_tokenizer_init:
if model_config.skip_tokenizer_init:
tokenizer = cast(TokenizerLike, object())
elif tokenizer is None:
tokenizer = cached_tokenizer_from_config(model_config)
return InputProcessingContext(model_config, tokenizer)
def _create_processing_info(
self,
model_config: "ModelConfig",
*,
tokenizer: AnyTokenizer | None = None,
tokenizer: TokenizerLike | None = None,
) -> BaseProcessingInfo:
model_cls = self._get_model_cls(model_config)
factories = model_cls._processor_factory
@ -252,7 +256,7 @@ class MultiModalRegistry:
self,
model_config: "ModelConfig",
*,
tokenizer: AnyTokenizer | None = None,
tokenizer: TokenizerLike | None = None,
cache: BaseMultiModalProcessorCache | None = None,
) -> BaseMultiModalProcessor[BaseProcessingInfo]:
"""
View File
@ -19,12 +19,12 @@ if TYPE_CHECKING:
DeltaMessage,
ResponsesRequest,
)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
else:
ChatCompletionRequest = Any
DeltaMessage = Any
ResponsesRequest = Any
AnyTokenizer = Any
TokenizerLike = Any
logger = init_logger(__name__)
@ -37,7 +37,7 @@ class ReasoningParser:
It is used to extract reasoning content from the model output.
"""
def __init__(self, tokenizer: AnyTokenizer, *args, **kwargs):
def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
self.model_tokenizer = tokenizer
@cached_property
View File
@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any
from vllm.entrypoints.openai.protocol import DeltaMessage
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
if TYPE_CHECKING:
from vllm.entrypoints.openai.protocol import (
@ -43,7 +43,7 @@ class BaseThinkingReasoningParser(ReasoningParser):
"""The token that ends reasoning content."""
raise NotImplementedError
def __init__(self, tokenizer: AnyTokenizer, *args, **kwargs):
def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
super().__init__(tokenizer, *args, **kwargs)
if not self.model_tokenizer:
View File
@ -11,7 +11,7 @@ from vllm.entrypoints.openai.protocol import (
from vllm.logger import init_logger
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
@ -37,7 +37,7 @@ class MiniMaxM2AppendThinkReasoningParser(ReasoningParser):
Reasoning parser for MiniMax M2 model.
"""
def __init__(self, tokenizer: AnyTokenizer, *args, **kwargs):
def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
super().__init__(tokenizer, *args, **kwargs)
self.end_token_id = self.vocab.get("</think>")
View File
@ -6,7 +6,7 @@ from functools import cached_property
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser
from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.tokenizers import MistralTokenizer
logger = init_logger(__name__)
View File
@ -9,7 +9,7 @@ from typing import TYPE_CHECKING
import regex as re
if TYPE_CHECKING:
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
@ -220,7 +220,7 @@ class Olmo3ReasoningParser(ReasoningParser):
token is missing from generation.
"""
def __init__(self, tokenizer: "AnyTokenizer", *args, **kwargs):
def __init__(self, tokenizer: "TokenizerLike", *args, **kwargs):
super().__init__(tokenizer, *args, **kwargs)
self.think_start = r"<think>"
View File
@ -13,7 +13,7 @@ from pydantic.dataclasses import dataclass
from vllm.logger import init_logger
from vllm.logits_process import LogitsProcessor
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.v1.serial_utils import PydanticMsgspecMixin
logger = init_logger(__name__)
@ -477,7 +477,7 @@ class SamplingParams(
eos_ids.update(self.stop_token_ids)
self.stop_token_ids = list(eos_ids)
def update_from_tokenizer(self, tokenizer: AnyTokenizer) -> None:
def update_from_tokenizer(self, tokenizer: TokenizerLike) -> None:
if not self.bad_words:
return
self._bad_words_token_ids = []
Some files were not shown because too many files have changed in this diff.
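For downstream code, the whole diff boils down to an import move plus the `AnyTokenizer` to `TokenizerLike` rename. A hedged compatibility sketch (the try/except fallback is an assumption for spanning older releases, not something this commit ships):

try:
    # Location after this refactor.
    from vllm.tokenizers import MistralTokenizer, TokenizerLike
except ImportError:
    # Older vLLM exposed the same interface under the AnyTokenizer name.
    from vllm.transformers_utils.tokenizer import (  # type: ignore[no-redef]
        AnyTokenizer as TokenizerLike,
        MistralTokenizer,
    )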