[Misc] Refactor tokenizer interface (#29693)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-11-29 20:02:21 +08:00 committed by GitHub
parent f223ed4181
commit 34a984274e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
119 changed files with 752 additions and 821 deletions

View File

@ -316,7 +316,7 @@ steps:
source_file_dependencies: source_file_dependencies:
- vllm/ - vllm/
- tests/engine - tests/engine
- tests/tokenization - tests/tokenizers_
- tests/test_sequence - tests/test_sequence
- tests/test_config - tests/test_config
- tests/test_logger - tests/test_logger
@ -324,7 +324,7 @@ steps:
commands: commands:
- pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py - pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py
# OOM in the CI unless we run this separately # OOM in the CI unless we run this separately
- pytest -v -s tokenization - pytest -v -s tokenizers_
- label: V1 Test e2e + engine # 30min - label: V1 Test e2e + engine # 30min
timeout_in_minutes: 45 timeout_in_minutes: 45

View File

@ -282,7 +282,7 @@ steps:
source_file_dependencies: source_file_dependencies:
- vllm/ - vllm/
- tests/engine - tests/engine
- tests/tokenization - tests/tokenizers_
- tests/test_sequence - tests/test_sequence
- tests/test_config - tests/test_config
- tests/test_logger - tests/test_logger
@ -290,7 +290,7 @@ steps:
commands: commands:
- pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py - pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py
# OOM in the CI unless we run this separately # OOM in the CI unless we run this separately
- pytest -v -s tokenization - pytest -v -s tokenizers_
- label: V1 Test e2e + engine # 30min - label: V1 Test e2e + engine # 30min
timeout_in_minutes: 45 timeout_in_minutes: 45

View File

@ -620,7 +620,7 @@ def get_tokenizer(
kwargs["use_fast"] = False kwargs["use_fast"] = False
if tokenizer_mode == "mistral": if tokenizer_mode == "mistral":
try: try:
from vllm.transformers_utils.tokenizer import MistralTokenizer from vllm.tokenizers import MistralTokenizer
except ImportError as e: except ImportError as e:
raise ImportError( raise ImportError(
"MistralTokenizer requires vllm package.\n" "MistralTokenizer requires vllm package.\n"

View File

@ -216,14 +216,13 @@ You can add a new `ReasoningParser` similar to [vllm/reasoning/deepseek_r1_reaso
# import the required packages # import the required packages
from vllm.reasoning import ReasoningParser, ReasoningParserManager from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
DeltaMessage)
# define a reasoning parser and register it to vllm # define a reasoning parser and register it to vllm
# the name list in register_module can be used # the name list in register_module can be used
# in --reasoning-parser. # in --reasoning-parser.
class ExampleParser(ReasoningParser): class ExampleParser(ReasoningParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
def extract_reasoning_streaming( def extract_reasoning_streaming(

View File

@ -422,7 +422,7 @@ Here is a summary of a plugin file:
# in --tool-call-parser. you can define as many # in --tool-call-parser. you can define as many
# tool parsers as you want here. # tool parsers as you want here.
class ExampleToolParser(ToolParser): class ExampleToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
# adjust request. e.g.: set skip special tokens # adjust request. e.g.: set skip special tokens

View File

@ -10,7 +10,7 @@ import pytest
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from vllm.tokenizers import MistralTokenizer
@pytest.fixture() @pytest.fixture()

View File

@ -4,9 +4,9 @@
import pytest import pytest
from transformers import AutoTokenizer from transformers import AutoTokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def default_tokenizer() -> AnyTokenizer: def default_tokenizer() -> TokenizerLike:
return AutoTokenizer.from_pretrained("gpt2") return AutoTokenizer.from_pretrained("gpt2")

View File

@ -7,7 +7,7 @@ import pytest
from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from ....utils import RemoteOpenAIServer from ....utils import RemoteOpenAIServer
@ -270,14 +270,14 @@ async def test_streaming_product_tool_call():
@pytest.fixture @pytest.fixture
def qwen_tokenizer() -> AnyTokenizer: def qwen_tokenizer() -> TokenizerLike:
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
return get_tokenizer("Qwen/Qwen3-32B") return get_tokenizer("Qwen/Qwen3-32B")
@pytest.fixture @pytest.fixture
def hermes_parser(qwen_tokenizer: AnyTokenizer) -> Hermes2ProToolParser: def hermes_parser(qwen_tokenizer: TokenizerLike) -> Hermes2ProToolParser:
return Hermes2ProToolParser(qwen_tokenizer) return Hermes2ProToolParser(qwen_tokenizer)
@ -291,7 +291,7 @@ def any_chat_request() -> ChatCompletionRequest:
def test_hermes_parser_streaming_just_forward_text( def test_hermes_parser_streaming_just_forward_text(
qwen_tokenizer: AnyTokenizer, qwen_tokenizer: TokenizerLike,
hermes_parser: Hermes2ProToolParser, hermes_parser: Hermes2ProToolParser,
any_chat_request: ChatCompletionRequest, any_chat_request: ChatCompletionRequest,
) -> None: ) -> None:
@ -323,7 +323,7 @@ def test_hermes_parser_streaming_just_forward_text(
def test_hermes_parser_streaming_failure_case_bug_19056( def test_hermes_parser_streaming_failure_case_bug_19056(
qwen_tokenizer: AnyTokenizer, qwen_tokenizer: TokenizerLike,
hermes_parser: Hermes2ProToolParser, hermes_parser: Hermes2ProToolParser,
any_chat_request: ChatCompletionRequest, any_chat_request: ChatCompletionRequest,
) -> None: ) -> None:
@ -357,7 +357,7 @@ def test_hermes_parser_streaming_failure_case_bug_19056(
def test_hermes_parser_streaming( def test_hermes_parser_streaming(
qwen_tokenizer: AnyTokenizer, qwen_tokenizer: TokenizerLike,
hermes_parser: Hermes2ProToolParser, hermes_parser: Hermes2ProToolParser,
any_chat_request: ChatCompletionRequest, any_chat_request: ChatCompletionRequest,
) -> None: ) -> None:

View File

@ -7,11 +7,11 @@ import pytest
from vllm.entrypoints.openai.protocol import ExtractedToolCallInformation from vllm.entrypoints.openai.protocol import ExtractedToolCallInformation
from vllm.entrypoints.openai.tool_parsers.llama_tool_parser import Llama3JsonToolParser from vllm.entrypoints.openai.tool_parsers.llama_tool_parser import Llama3JsonToolParser
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
@pytest.fixture @pytest.fixture
def parser(default_tokenizer: AnyTokenizer): def parser(default_tokenizer: TokenizerLike):
return Llama3JsonToolParser(default_tokenizer) return Llama3JsonToolParser(default_tokenizer)

View File

@ -11,7 +11,7 @@ from tests.entrypoints.openai.tool_parsers.utils import (
) )
from vllm.entrypoints.openai.protocol import FunctionCall from vllm.entrypoints.openai.protocol import FunctionCall
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
# Test cases similar to pythonic parser but with Llama4 specific format # Test cases similar to pythonic parser but with Llama4 specific format
SIMPLE_FUNCTION_OUTPUT = "[get_weather(city='LA', metric='C')]" SIMPLE_FUNCTION_OUTPUT = "[get_weather(city='LA', metric='C')]"
@ -64,7 +64,7 @@ PYTHON_TAG_FUNCTION_OUTPUT = (
@pytest.mark.parametrize("streaming", [True, False]) @pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool, default_tokenizer: AnyTokenizer): def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")( tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
default_tokenizer default_tokenizer
) )
@ -208,7 +208,7 @@ def test_tool_call(
streaming: bool, streaming: bool,
model_output: str, model_output: str,
expected_tool_calls: list[FunctionCall], expected_tool_calls: list[FunctionCall],
default_tokenizer: AnyTokenizer, default_tokenizer: TokenizerLike,
): ):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")( tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
default_tokenizer default_tokenizer
@ -224,7 +224,7 @@ def test_tool_call(
assert actual.function == expected assert actual.function == expected
def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer): def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")( tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
default_tokenizer default_tokenizer
) )
@ -246,7 +246,7 @@ def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
@pytest.mark.parametrize("streaming", [False]) @pytest.mark.parametrize("streaming", [False])
def test_regex_timeout_handling(streaming: bool, default_tokenizer: AnyTokenizer): def test_regex_timeout_handling(streaming: bool, default_tokenizer: TokenizerLike):
"""test regex timeout is handled gracefully""" """test regex timeout is handled gracefully"""
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")( tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
default_tokenizer default_tokenizer

View File

@ -11,7 +11,7 @@ from tests.entrypoints.openai.tool_parsers.utils import (
) )
from vllm.entrypoints.openai.protocol import FunctionCall from vllm.entrypoints.openai.protocol import FunctionCall
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1 # https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')" SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"
@ -69,7 +69,7 @@ ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
@pytest.mark.parametrize("streaming", [True, False]) @pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool, default_tokenizer: AnyTokenizer): def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")( tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
default_tokenizer default_tokenizer
) )
@ -188,7 +188,7 @@ def test_tool_call(
streaming: bool, streaming: bool,
model_output: str, model_output: str,
expected_tool_calls: list[FunctionCall], expected_tool_calls: list[FunctionCall],
default_tokenizer: AnyTokenizer, default_tokenizer: TokenizerLike,
): ):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")( tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
default_tokenizer default_tokenizer
@ -205,7 +205,7 @@ def test_tool_call(
assert actual.function == expected assert actual.function == expected
def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer): def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")( tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
default_tokenizer default_tokenizer
) )
@ -228,7 +228,7 @@ def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
@pytest.mark.parametrize("streaming", [False]) @pytest.mark.parametrize("streaming", [False])
def test_regex_timeout_handling(streaming: bool, default_tokenizer: AnyTokenizer): def test_regex_timeout_handling(streaming: bool, default_tokenizer: TokenizerLike):
"""test regex timeout is handled gracefully""" """test regex timeout is handled gracefully"""
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")( tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
default_tokenizer default_tokenizer

View File

@ -11,7 +11,7 @@ from tests.entrypoints.openai.tool_parsers.utils import (
) )
from vllm.entrypoints.openai.protocol import FunctionCall from vllm.entrypoints.openai.protocol import FunctionCall
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1 # https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')" SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"
@ -61,7 +61,7 @@ ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
@pytest.mark.parametrize("streaming", [True, False]) @pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool, default_tokenizer: AnyTokenizer): def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")( tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
default_tokenizer default_tokenizer
) )
@ -168,7 +168,7 @@ def test_tool_call(
streaming: bool, streaming: bool,
model_output: str, model_output: str,
expected_tool_calls: list[FunctionCall], expected_tool_calls: list[FunctionCall],
default_tokenizer: AnyTokenizer, default_tokenizer: TokenizerLike,
): ):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")( tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
default_tokenizer default_tokenizer
@ -185,7 +185,7 @@ def test_tool_call(
assert actual.function == expected assert actual.function == expected
def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer): def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")( tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
default_tokenizer default_tokenizer
) )
@ -208,7 +208,7 @@ def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
@pytest.mark.parametrize("streaming", [False]) @pytest.mark.parametrize("streaming", [False])
def test_regex_timeout_handling(streaming: bool, default_tokenizer: AnyTokenizer): def test_regex_timeout_handling(streaming: bool, default_tokenizer: TokenizerLike):
"""test regex timeout is handled gracefully""" """test regex timeout is handled gracefully"""
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")( tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
default_tokenizer default_tokenizer

View File

@ -11,7 +11,7 @@ from vllm.entrypoints.openai.protocol import (
ToolCall, ToolCall,
) )
from vllm.entrypoints.openai.tool_parsers import ToolParser from vllm.entrypoints.openai.tool_parsers import ToolParser
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
class StreamingToolReconstructor: class StreamingToolReconstructor:
@ -111,7 +111,7 @@ def run_tool_extraction_nonstreaming(
return tool_parser.extract_tool_calls(model_output, request) return tool_parser.extract_tool_calls(model_output, request)
def split_string_into_token_deltas(tokenizer: AnyTokenizer, text: str) -> list[str]: def split_string_into_token_deltas(tokenizer: TokenizerLike, text: str) -> list[str]:
# Split a string into a series of deltas using the provided tokenizer. Each # Split a string into a series of deltas using the provided tokenizer. Each
# delta will be the string equivalent of a single token. # delta will be the string equivalent of a single token.
token_ids = tokenizer.encode(text, add_special_tokens=False) token_ids = tokenizer.encode(text, add_special_tokens=False)

View File

@ -28,8 +28,8 @@ from vllm.multimodal.utils import (
encode_image_base64, encode_image_base64,
encode_video_base64, encode_video_base64,
) )
from vllm.tokenizers import MistralTokenizer
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from ..models.registry import HF_EXAMPLE_MODELS from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import VLLM_PATH from ..utils import VLLM_PATH

View File

@ -10,7 +10,7 @@ from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
MistralToolParser, MistralToolParser,
) )
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import MistralTokenizer from vllm.tokenizers import MistralTokenizer
from ...utils import check_logprobs_close from ...utils import check_logprobs_close

View File

@ -9,7 +9,7 @@ from mistral_common.audio import Audio
from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage from mistral_common.protocol.instruct.messages import UserMessage
from vllm.transformers_utils.tokenizer import MistralTokenizer from vllm.tokenizers import MistralTokenizer
from ....conftest import AudioTestAssets from ....conftest import AudioTestAssets
from ....utils import RemoteOpenAIServer from ....utils import RemoteOpenAIServer

View File

@ -9,7 +9,7 @@ import torch
from transformers.models.auto.auto_factory import _BaseAutoModelClass from transformers.models.auto.auto_factory import _BaseAutoModelClass
from vllm.config.model import RunnerOption from vllm.config.model import RunnerOption
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from .....conftest import HfRunner, VllmRunner from .....conftest import HfRunner, VllmRunner
from ....registry import HF_EXAMPLE_MODELS from ....registry import HF_EXAMPLE_MODELS
@ -33,7 +33,7 @@ def run_test(
auto_cls: type[_BaseAutoModelClass], auto_cls: type[_BaseAutoModelClass],
use_tokenizer_eos: bool, use_tokenizer_eos: bool,
comparator: Callable[..., None], comparator: Callable[..., None],
get_stop_token_ids: Callable[[AnyTokenizer], list[int]] | None, get_stop_token_ids: Callable[[TokenizerLike], list[int]] | None,
stop_str: list[str] | None, stop_str: list[str] | None,
limit_mm_per_prompt: dict[str, int], limit_mm_per_prompt: dict[str, int],
vllm_runner_kwargs: dict[str, Any] | None, vllm_runner_kwargs: dict[str, Any] | None,

View File

@ -14,7 +14,7 @@ from transformers.models.auto.auto_factory import _BaseAutoModelClass
from vllm.config.model import RunnerOption from vllm.config.model import RunnerOption
from vllm.logprobs import SampleLogprobs from vllm.logprobs import SampleLogprobs
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from .....conftest import ( from .....conftest import (
AUDIO_ASSETS, AUDIO_ASSETS,
@ -126,7 +126,7 @@ class VLMTestInfo(NamedTuple):
vllm_runner_kwargs: dict[str, Any] | None = None vllm_runner_kwargs: dict[str, Any] | None = None
# Optional callable which gets a list of token IDs from the model tokenizer # Optional callable which gets a list of token IDs from the model tokenizer
get_stop_token_ids: Callable[[AnyTokenizer], list[int]] | None = None get_stop_token_ids: Callable[[TokenizerLike], list[int]] | None = None
# Optional list of strings to stop generation, useful when stop tokens are # Optional list of strings to stop generation, useful when stop tokens are
# not special tokens in the tokenizer # not special tokens in the tokenizer
stop_str: list[str] | None = None stop_str: list[str] | None = None

View File

@ -22,8 +22,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
from vllm.multimodal.cache import MultiModalProcessorOnlyCache from vllm.multimodal.cache import MultiModalProcessorOnlyCache
from vllm.multimodal.inputs import MultiModalInputs from vllm.multimodal.inputs import MultiModalInputs
from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext
from vllm.tokenizers import MistralTokenizer
from vllm.transformers_utils.tokenizer import ( from vllm.transformers_utils.tokenizer import (
MistralTokenizer,
cached_tokenizer_from_config, cached_tokenizer_from_config,
encode_tokens, encode_tokens,
) )

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from contextlib import nullcontext from contextlib import nullcontext
from typing import cast from typing import cast
@ -23,7 +24,7 @@ from vllm.multimodal.processing import (
replace_token_matches, replace_token_matches,
) )
from vllm.multimodal.profiling import MultiModalProfiler from vllm.multimodal.profiling import MultiModalProfiler
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from .utils import random_image from .utils import random_image
@ -238,7 +239,7 @@ def test_find_token_matches(
update_type, update_type,
): ):
# Should not be used since there is nothing to convert to token IDs # Should not be used since there is nothing to convert to token IDs
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(TokenizerLike, object())
prompt_updates = { prompt_updates = {
key: update_type(key, target, []).resolve(0) key: update_type(key, target, []).resolve(0)
@ -385,7 +386,7 @@ def test_find_text_matches(
update_type, update_type,
): ):
# Should not be used since there is nothing to convert to text # Should not be used since there is nothing to convert to text
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(TokenizerLike, object())
prompt_updates = { prompt_updates = {
key: update_type(key, target, []).resolve(0) key: update_type(key, target, []).resolve(0)
@ -545,7 +546,7 @@ def test_find_update_text(
expected_by_update_type_mm_count, expected_by_update_type_mm_count,
): ):
# Should not be used since there is nothing to convert to text # Should not be used since there is nothing to convert to text
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(TokenizerLike, object())
for ( for (
update_type, update_type,
@ -750,7 +751,7 @@ def test_find_update_tokens(
expected_by_update_type_mm_count, expected_by_update_type_mm_count,
): ):
# Should not be used since there is nothing to convert to tokens # Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(TokenizerLike, object())
for ( for (
update_type, update_type,
@ -900,7 +901,7 @@ def test_find_mm_placeholders(
update_type, update_type,
): ):
# Should not be used since there is nothing to convert to tokens # Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(TokenizerLike, object())
mm_prompt_updates = { mm_prompt_updates = {
key: [[update_type(key, [], repl).resolve(i)] for i in range(3)] key: [[update_type(key, [], repl).resolve(i)] for i in range(3)]
@ -1029,7 +1030,7 @@ def test_hf_processor_init_kwargs(
expected_kwargs, expected_kwargs,
): ):
# Should not be used since there is nothing to convert to tokens # Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(TokenizerLike, object())
ctx = InputProcessingContext( ctx = InputProcessingContext(
model_config=ModelConfig(model_id, mm_processor_kwargs=config_kwargs), model_config=ModelConfig(model_id, mm_processor_kwargs=config_kwargs),
@ -1065,7 +1066,7 @@ def test_hf_processor_call_kwargs(
expected_kwargs, expected_kwargs,
): ):
# Should not be used since there is nothing to convert to tokens # Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(TokenizerLike, object())
ctx = InputProcessingContext( ctx = InputProcessingContext(
model_config=ModelConfig(model_id, mm_processor_kwargs=config_kwargs), model_config=ModelConfig(model_id, mm_processor_kwargs=config_kwargs),
@ -1088,9 +1089,7 @@ def test_apply_matches_no_match_exits_quickly():
With the fix, it should exit immediately when no match is found. With the fix, it should exit immediately when no match is found.
""" """
import time mock_tokenizer = cast(TokenizerLike, object())
mock_tokenizer = cast(AnyTokenizer, object())
# Create a long prompt with no placeholder # Create a long prompt with no placeholder
long_prompt = "x" * 10000 long_prompt = "x" * 10000

View File

@ -5,7 +5,7 @@ import pytest
from tests.reasoning.utils import run_reasoning_extraction_mistral from tests.reasoning.utils import run_reasoning_extraction_mistral
from vllm.reasoning import ReasoningParser, ReasoningParserManager from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from vllm.tokenizers import MistralTokenizer
parser_name = "mistral" parser_name = "mistral"

View File

@ -4,7 +4,7 @@
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
from vllm.reasoning import ReasoningParser from vllm.reasoning import ReasoningParser
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from vllm.tokenizers import MistralTokenizer
class StreamingReasoningReconstructor: class StreamingReasoningReconstructor:

View File

@ -1,18 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.transformers_utils.tokenizer import get_tokenizer
TOKENIZER_NAMES = ["BAAI/bge-base-en"]
@pytest.mark.parametrize("tokenizer_name", TOKENIZER_NAMES)
@pytest.mark.parametrize("n_tokens", [510])
def test_special_tokens(tokenizer_name: str, n_tokens: int):
    """Encoding a run of `[UNK]` tokens should add exactly two special ids.

    The tokenizer wraps the sequence with its special tokens
    (presumably CLS/SEP for BGE — confirm against the tokenizer config).
    """
    tokenizer = get_tokenizer(tokenizer_name, revision="main")
    repeated_unk = "".join(["[UNK]"] * n_tokens)
    encoded_ids = tokenizer.encode(repeated_unk)
    # One id per "[UNK]" plus the two wrapper special tokens.
    assert len(encoded_ids) == n_tokens + 2

View File

@ -1,32 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This test file includes some cases where it is inappropriate to
only get the `eos_token_id` from the tokenizer as defined by
{meth}`vllm.LLMEngine._get_eos_token_id`.
"""
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.tokenizer import get_tokenizer
def test_get_llama3_eos_token():
    """Llama 3.2's generation config lists more EOS ids than the tokenizer.

    This demonstrates why reading only ``tokenizer.eos_token_id`` is
    insufficient: the generation config declares three stop ids.
    """
    model = "meta-llama/Llama-3.2-1B-Instruct"

    assert get_tokenizer(model).eos_token_id == 128009

    gen_config = try_get_generation_config(model, trust_remote_code=False)
    assert gen_config is not None
    assert gen_config.eos_token_id == [128001, 128008, 128009]
def test_get_blip2_eos_token():
    """BLIP-2's generation config overrides the tokenizer's EOS id.

    The tokenizer reports id 2 while the generation config specifies
    50118, so the two sources disagree and must both be consulted.
    """
    model = "Salesforce/blip2-opt-2.7b"

    assert get_tokenizer(model).eos_token_id == 2

    gen_config = try_get_generation_config(model, trust_remote_code=False)
    assert gen_config is not None
    assert gen_config.eos_token_id == 50118

View File

@ -1,23 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from transformers import PreTrainedTokenizerBase
from vllm.transformers_utils.tokenizer import get_tokenizer
TOKENIZER_NAMES = [
"facebook/opt-125m",
"gpt2",
]
@pytest.mark.parametrize("tokenizer_name", TOKENIZER_NAMES)
def test_tokenizer_revision(tokenizer_name: str):
    """`get_tokenizer` must honor the HF Hub ``revision`` argument."""
    # The "main" branch is assumed to always exist for these repos.
    tokenizer = get_tokenizer(tokenizer_name, revision="main")
    assert isinstance(tokenizer, PreTrainedTokenizerBase)

    # A branch named "never" is assumed to never exist; HF Hub surfaces
    # an unknown revision as an OSError.
    with pytest.raises(OSError, match="not a valid git identifier"):
        get_tokenizer(tokenizer_name, revision="never")

View File

@ -1,120 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Any
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.transformers_utils.tokenizer_base import TokenizerBase, TokenizerRegistry
if TYPE_CHECKING:
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
class TestTokenizer(TokenizerBase):
    """Minimal stub tokenizer for exercising the tokenizer registry.

    Only ``bos_token_id``/``eos_token_id`` return real values; every
    other member raises ``NotImplementedError`` because the registry
    tests never touch them.
    """

    @classmethod
    def from_pretrained(cls, *args, **kwargs) -> "TestTokenizer":
        # The registry constructs tokenizers through this hook; the
        # arguments are irrelevant for the stub.
        return TestTokenizer()

    @property
    def all_special_tokens(self) -> list[str]:
        raise NotImplementedError

    @property
    def all_special_ids(self) -> list[int]:
        raise NotImplementedError

    @property
    def bos_token_id(self) -> int:
        # Fixed sentinel asserted by the registry test.
        return 0

    @property
    def eos_token_id(self) -> int:
        # Fixed sentinel asserted by the registry test.
        return 1

    @property
    def sep_token(self) -> str:
        raise NotImplementedError

    @property
    def pad_token(self) -> str:
        raise NotImplementedError

    @property
    def is_fast(self) -> bool:
        raise NotImplementedError

    @property
    def vocab_size(self) -> int:
        raise NotImplementedError

    @property
    def max_token_id(self) -> int:
        raise NotImplementedError

    @property
    def truncation_side(self) -> str:
        raise NotImplementedError

    def __call__(
        self,
        text: str | list[str] | list[int],
        text_pair: str | None = None,
        add_special_tokens: bool = False,
        truncation: bool = False,
        max_length: int | None = None,
    ):
        raise NotImplementedError

    def get_vocab(self) -> dict[str, int]:
        raise NotImplementedError

    def get_added_vocab(self) -> dict[str, int]:
        raise NotImplementedError

    def encode_one(
        self,
        text: str,
        truncation: bool = False,
        max_length: int | None = None,
    ) -> list[int]:
        raise NotImplementedError

    def encode(self, text: str, add_special_tokens: bool | None = None) -> list[int]:
        raise NotImplementedError

    def apply_chat_template(
        self,
        messages: list["ChatCompletionMessageParam"],
        tools: list[dict[str, Any]] | None = None,
        **kwargs,
    ) -> list[int]:
        raise NotImplementedError

    def convert_tokens_to_string(self, tokens: list[str]) -> str:
        raise NotImplementedError

    def decode(self, ids: list[int] | int, skip_special_tokens: bool = True) -> str:
        raise NotImplementedError

    def convert_ids_to_tokens(
        self,
        ids: list[int],
        skip_special_tokens: bool = True,
    ) -> list[str]:
        raise NotImplementedError
def test_customized_tokenizer():
    """A tokenizer registered via ``TokenizerRegistry`` is resolvable both
    directly and through ``get_tokenizer(..., tokenizer_mode="custom")``."""
    TokenizerRegistry.register(
        "test_tokenizer", "tests.tokenization.test_tokenizer_registry", "TestTokenizer"
    )

    # Direct registry lookup returns the stub with its hard-coded special ids.
    tokenizer = TokenizerRegistry.get_tokenizer("test_tokenizer")
    assert isinstance(tokenizer, TestTokenizer)
    assert tokenizer.bos_token_id == 0
    assert tokenizer.eos_token_id == 1

    # The generic entry point must resolve the same registered class.
    tokenizer = get_tokenizer("test_tokenizer", tokenizer_mode="custom")
    assert isinstance(tokenizer, TestTokenizer)
    assert tokenizer.bos_token_id == 0
    assert tokenizer.eos_token_id == 1

View File

@ -0,0 +1,4 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# NOTE: Since CI runs the tests from the `tests` directory, it is necessary to rename
# this module to avoid conflicting with HF's `tokenizers` package

View File

@ -0,0 +1,59 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import _get_protocol_attrs # type: ignore
import pytest
from transformers import PreTrainedTokenizerBase
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.tokenizer import get_tokenizer
def _get_missing_attrs(obj: object, target: type) -> list[str]:
    """Return the names of *target*'s protocol attributes that *obj* lacks."""
    return [k for k in _get_protocol_attrs(target) if not hasattr(obj, k)]
def test_tokenizer_like_protocol():
    """Each supported tokenizer flavour must satisfy ``TokenizerLike``."""
    configs = [
        (("gpt2",), {"use_fast": False}),
        (("gpt2",), {"use_fast": True}),
        (
            ("mistralai/Mistral-7B-Instruct-v0.3",),
            {"tokenizer_mode": "mistral"},
        ),
    ]
    for args, kwargs in configs:
        tokenizer = get_tokenizer(*args, **kwargs)
        missing_attrs = _get_missing_attrs(tokenizer, TokenizerLike)
        assert not missing_attrs, f"Missing attrs: {missing_attrs}"
@pytest.mark.parametrize("tokenizer_name", ["facebook/opt-125m", "gpt2"])
def test_tokenizer_revision(tokenizer_name: str):
    """``get_tokenizer`` should honor the ``revision`` argument."""
    # Assume that the "main" branch always exists on the Hub.
    tokenizer = get_tokenizer(tokenizer_name, revision="main")
    assert isinstance(tokenizer, PreTrainedTokenizerBase)

    # Assume that a branch named "never" does not exist on the Hub.
    with pytest.raises(OSError, match="not a valid git identifier"):
        get_tokenizer(tokenizer_name, revision="never")
@pytest.mark.parametrize("tokenizer_name", ["BAAI/bge-base-en"])
@pytest.mark.parametrize("n_tokens", [510])
def test_special_tokens(tokenizer_name: str, n_tokens: int):
    """encode() is expected to add exactly two special tokens to the prompt."""
    tokenizer = get_tokenizer(tokenizer_name, revision="main")

    prompt = "[UNK]" * n_tokens
    encoded = tokenizer.encode(prompt)
    # n_tokens "[UNK]" pieces plus two extra special tokens.
    assert len(encoded) == n_tokens + 2

View File

@ -6,7 +6,8 @@ from copy import deepcopy
import pytest import pytest
from transformers import AutoTokenizer from transformers import AutoTokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_cached_tokenizer from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.tokenizer import get_cached_tokenizer
@pytest.mark.parametrize("model_id", ["gpt2", "zai-org/chatglm3-6b"]) @pytest.mark.parametrize("model_id", ["gpt2", "zai-org/chatglm3-6b"])
@ -25,7 +26,7 @@ def test_cached_tokenizer(model_id: str):
_check_consistency(unpickled_tokenizer, reference_tokenizer) _check_consistency(unpickled_tokenizer, reference_tokenizer)
def _check_consistency(target: AnyTokenizer, expected: AnyTokenizer): def _check_consistency(target: TokenizerLike, expected: TokenizerLike):
assert isinstance(target, type(expected)) assert isinstance(target, type(expected))
# Cached attributes # Cached attributes

View File

@ -8,7 +8,7 @@ import pytest
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from vllm.tokenizers import MistralTokenizer
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.detokenizer import ( from vllm.v1.engine.detokenizer import (
FastIncrementalDetokenizer, FastIncrementalDetokenizer,

View File

@ -7,7 +7,7 @@ import pytest
from mistral_common.exceptions import InvalidMessageStructureException from mistral_common.exceptions import InvalidMessageStructureException
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
from vllm.transformers_utils.tokenizers.mistral import ( from vllm.tokenizers.mistral import (
MistralTokenizer, MistralTokenizer,
_prepare_apply_chat_template_tools_and_messages, _prepare_apply_chat_template_tools_and_messages,
) )
@ -308,25 +308,6 @@ class TestMistralTokenizer:
def test_get_added_vocab(self, mistral_tokenizer: MistralTokenizer): def test_get_added_vocab(self, mistral_tokenizer: MistralTokenizer):
assert mistral_tokenizer.get_added_vocab() == {} assert mistral_tokenizer.get_added_vocab() == {}
def test_encode_one(self, mistral_tokenizer: MistralTokenizer):
    """``encode_one`` adds no special tokens and only truncates when asked."""
    # Tekken and SentencePiece backends yield different ids for the same text.
    token_ids = (
        [22177, 4304, 2662] if mistral_tokenizer.is_tekken else [23325, 2294, 1686]
    )
    assert mistral_tokenizer.encode_one("Hello world !") == token_ids
    # max_length alone has no effect; truncation must be explicitly enabled.
    assert mistral_tokenizer.encode_one("Hello world !", max_length=1) == token_ids
    assert (
        mistral_tokenizer.encode_one("Hello world !", truncation=True, max_length=1)
        == token_ids[:-2]
    )
    assert (
        mistral_tokenizer.encode_one(
            "Hello world !", truncation=False, max_length=1
        )
        == token_ids
    )
    # Empty input encodes to an empty id list.
    assert mistral_tokenizer.encode_one("") == []
def test_encode(self, mistral_tokenizer: MistralTokenizer): def test_encode(self, mistral_tokenizer: MistralTokenizer):
token_ids = ( token_ids = (
[1, 22177, 4304, 2662] [1, 22177, 4304, 2662]

View File

@ -0,0 +1,36 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.tokenizers import TokenizerLike, TokenizerRegistry
from vllm.transformers_utils.tokenizer import get_tokenizer
class TestTokenizer(TokenizerLike):
    """Minimal ``TokenizerLike`` stub used to exercise ``TokenizerRegistry``.

    Only the special-token ids are implemented; that is all the registry
    test below inspects.
    """

    @classmethod
    def from_pretrained(cls, *args, **kwargs) -> "TestTokenizer":
        # All arguments are ignored; the stub has no pretrained state to load.
        return TestTokenizer()  # type: ignore

    @property
    def bos_token_id(self) -> int:
        # Hard-coded id asserted by test_customized_tokenizer below.
        return 0

    @property
    def eos_token_id(self) -> int:
        # Hard-coded id asserted by test_customized_tokenizer below.
        return 1
def test_customized_tokenizer():
    """A registered custom tokenizer is reachable via both lookup paths."""
    TokenizerRegistry.register(
        "test_tokenizer",
        __name__,
        TestTokenizer.__name__,
    )

    def _check(tok) -> None:
        # Both lookup paths must yield the stub with its hard-coded ids.
        assert isinstance(tok, TestTokenizer)
        assert tok.bos_token_id == 0
        assert tok.eos_token_id == 1

    _check(TokenizerRegistry.get_tokenizer("test_tokenizer"))
    _check(get_tokenizer("test_tokenizer", tokenizer_mode="custom"))

View File

@ -14,8 +14,9 @@ from vllm.entrypoints.openai.protocol import (
ToolCall, ToolCall,
) )
from vllm.entrypoints.openai.tool_parsers.ernie45_tool_parser import Ernie45ToolParser from vllm.entrypoints.openai.tool_parsers.ernie45_tool_parser import Ernie45ToolParser
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
# Use a common model that is likely to be available # Use a common model that is likely to be available
MODEL = "baidu/ERNIE-4.5-21B-A3B-Thinking" MODEL = "baidu/ERNIE-4.5-21B-A3B-Thinking"
@ -173,7 +174,7 @@ def test_extract_tool_calls(
def stream_delta_message_generator( def stream_delta_message_generator(
ernie45_tool_parser: Ernie45ToolParser, ernie45_tool_parser: Ernie45ToolParser,
ernie45_tokenizer: AnyTokenizer, ernie45_tokenizer: TokenizerLike,
model_output: str, model_output: str,
request: ChatCompletionRequest | None = None, request: ChatCompletionRequest | None = None,
) -> Generator[DeltaMessage, None, None]: ) -> Generator[DeltaMessage, None, None]:

View File

@ -10,8 +10,9 @@ from partial_json_parser.core.options import Allow
from vllm.entrypoints.openai.protocol import DeltaMessage, FunctionCall, ToolCall from vllm.entrypoints.openai.protocol import DeltaMessage, FunctionCall, ToolCall
from vllm.entrypoints.openai.tool_parsers.jamba_tool_parser import JambaToolParser from vllm.entrypoints.openai.tool_parsers.jamba_tool_parser import JambaToolParser
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
pytestmark = pytest.mark.cpu_test pytestmark = pytest.mark.cpu_test
@ -44,7 +45,9 @@ def assert_tool_calls(
def stream_delta_message_generator( def stream_delta_message_generator(
jamba_tool_parser: JambaToolParser, jamba_tokenizer: AnyTokenizer, model_output: str jamba_tool_parser: JambaToolParser,
jamba_tokenizer: TokenizerLike,
model_output: str,
) -> Generator[DeltaMessage, None, None]: ) -> Generator[DeltaMessage, None, None]:
all_token_ids = jamba_tokenizer.encode(model_output, add_special_tokens=False) all_token_ids = jamba_tokenizer.encode(model_output, add_special_tokens=False)

View File

@ -17,8 +17,9 @@ from vllm.entrypoints.openai.tool_parsers.qwen3coder_tool_parser import (
Qwen3CoderToolParser, Qwen3CoderToolParser,
) )
from vllm.entrypoints.openai.tool_parsers.qwen3xml_tool_parser import Qwen3XMLToolParser from vllm.entrypoints.openai.tool_parsers.qwen3xml_tool_parser import Qwen3XMLToolParser
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
pytestmark = pytest.mark.cpu_test pytestmark = pytest.mark.cpu_test
@ -104,7 +105,7 @@ def assert_tool_calls(
def stream_delta_message_generator( def stream_delta_message_generator(
qwen3_tool_parser, qwen3_tool_parser,
qwen3_tokenizer: AnyTokenizer, qwen3_tokenizer: TokenizerLike,
model_output: str, model_output: str,
request: ChatCompletionRequest | None = None, request: ChatCompletionRequest | None = None,
) -> Generator[DeltaMessage, None, None]: ) -> Generator[DeltaMessage, None, None]:

View File

@ -15,8 +15,9 @@ from vllm.entrypoints.openai.protocol import (
ToolCall, ToolCall,
) )
from vllm.entrypoints.openai.tool_parsers.seed_oss_tool_parser import SeedOssToolParser from vllm.entrypoints.openai.tool_parsers.seed_oss_tool_parser import SeedOssToolParser
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
pytestmark = pytest.mark.cpu_test pytestmark = pytest.mark.cpu_test
@ -256,7 +257,7 @@ def test_streaming_tool_calls_no_tools(seed_oss_tool_parser):
def stream_delta_message_generator( def stream_delta_message_generator(
seed_oss_tool_parser: SeedOssToolParser, seed_oss_tool_parser: SeedOssToolParser,
seed_oss_tokenizer: AnyTokenizer, seed_oss_tokenizer: TokenizerLike,
model_output: str, model_output: str,
request: ChatCompletionRequest | None = None, request: ChatCompletionRequest | None = None,
) -> Generator[DeltaMessage, None, None]: ) -> Generator[DeltaMessage, None, None]:

View File

@ -13,8 +13,9 @@ from vllm.entrypoints.openai.protocol import (
ToolCall, ToolCall,
) )
from vllm.entrypoints.openai.tool_parsers.xlam_tool_parser import xLAMToolParser from vllm.entrypoints.openai.tool_parsers.xlam_tool_parser import xLAMToolParser
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
pytestmark = pytest.mark.cpu_test pytestmark = pytest.mark.cpu_test
@ -49,7 +50,7 @@ def assert_tool_calls(
def stream_delta_message_generator( def stream_delta_message_generator(
xlam_tool_parser: xLAMToolParser, xlam_tool_parser: xLAMToolParser,
xlam_tokenizer: AnyTokenizer, xlam_tokenizer: TokenizerLike,
model_output: str, model_output: str,
request: ChatCompletionRequest | None = None, request: ChatCompletionRequest | None = None,
) -> Generator[DeltaMessage, None, None]: ) -> Generator[DeltaMessage, None, None]:

View File

@ -1,62 +1,32 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This test file includes some cases where it is inappropriate to
only get the `eos_token_id` from the tokenizer as defined by
`vllm.LLMEngine._get_eos_token_id`.
"""
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.tokenizer import get_tokenizer
import tempfile def test_get_llama3_eos_token():
from pathlib import Path model_name = "meta-llama/Llama-3.2-1B-Instruct"
from unittest.mock import MagicMock, call, patch
import pytest tokenizer = get_tokenizer(model_name)
assert tokenizer.eos_token_id == 128009
from vllm.transformers_utils.repo_utils import list_filtered_repo_files generation_config = try_get_generation_config(model_name, trust_remote_code=False)
assert generation_config is not None
assert generation_config.eos_token_id == [128001, 128008, 128009]
@pytest.mark.parametrize( def test_get_blip2_eos_token():
"allow_patterns,expected_relative_files", model_name = "Salesforce/blip2-opt-2.7b"
[
(
["*.json", "correct*.txt"],
["json_file.json", "subfolder/correct.txt", "correct_2.txt"],
),
],
)
def test_list_filtered_repo_files(
allow_patterns: list[str], expected_relative_files: list[str]
):
with tempfile.TemporaryDirectory() as tmp_dir:
# Prep folder and files
path_tmp_dir = Path(tmp_dir)
subfolder = path_tmp_dir / "subfolder"
subfolder.mkdir()
(path_tmp_dir / "json_file.json").touch()
(path_tmp_dir / "correct_2.txt").touch()
(path_tmp_dir / "uncorrect.txt").touch()
(path_tmp_dir / "uncorrect.jpeg").touch()
(subfolder / "correct.txt").touch()
(subfolder / "uncorrect_sub.txt").touch()
def _glob_path() -> list[str]: tokenizer = get_tokenizer(model_name)
return [ assert tokenizer.eos_token_id == 2
str(file.relative_to(path_tmp_dir))
for file in path_tmp_dir.glob("**/*")
if file.is_file()
]
# Patch list_repo_files called by fn generation_config = try_get_generation_config(model_name, trust_remote_code=False)
with patch( assert generation_config is not None
"vllm.transformers_utils.repo_utils.list_repo_files", assert generation_config.eos_token_id == 50118
MagicMock(return_value=_glob_path()),
) as mock_list_repo_files:
out_files = sorted(
list_filtered_repo_files(
tmp_dir, allow_patterns, "revision", "model", "token"
)
)
assert out_files == sorted(expected_relative_files)
assert mock_list_repo_files.call_count == 1
assert mock_list_repo_files.call_args_list[0] == call(
repo_id=tmp_dir,
revision="revision",
repo_type="model",
token="token",
)

View File

@ -0,0 +1,62 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, call, patch
import pytest
from vllm.transformers_utils.repo_utils import list_filtered_repo_files
@pytest.mark.parametrize(
    "allow_patterns,expected_relative_files",
    [
        (
            ["*.json", "correct*.txt"],
            ["json_file.json", "subfolder/correct.txt", "correct_2.txt"],
        ),
    ],
)
def test_list_filtered_repo_files(
    allow_patterns: list[str], expected_relative_files: list[str]
):
    """``list_filtered_repo_files`` keeps only files matching ``allow_patterns``.

    A temporary repo layout is created on disk and ``list_repo_files`` is
    mocked to return that layout, so no network access is needed.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Prep folder and files
        path_tmp_dir = Path(tmp_dir)
        subfolder = path_tmp_dir / "subfolder"
        subfolder.mkdir()
        (path_tmp_dir / "json_file.json").touch()
        (path_tmp_dir / "correct_2.txt").touch()
        (path_tmp_dir / "uncorrect.txt").touch()
        (path_tmp_dir / "uncorrect.jpeg").touch()
        (subfolder / "correct.txt").touch()
        (subfolder / "uncorrect_sub.txt").touch()

        def _glob_path() -> list[str]:
            # File paths relative to the repo root (what the mock will return).
            return [
                str(file.relative_to(path_tmp_dir))
                for file in path_tmp_dir.glob("**/*")
                if file.is_file()
            ]

        # Patch list_repo_files called by fn
        with patch(
            "vllm.transformers_utils.repo_utils.list_repo_files",
            MagicMock(return_value=_glob_path()),
        ) as mock_list_repo_files:
            out_files = sorted(
                list_filtered_repo_files(
                    tmp_dir, allow_patterns, "revision", "model", "token"
                )
            )
            assert out_files == sorted(expected_relative_files)
            # The underlying Hub listing must be queried exactly once, with
            # all arguments forwarded verbatim.
            assert mock_list_repo_files.call_count == 1
            assert mock_list_repo_files.call_args_list[0] == call(
                repo_id=tmp_dir,
                revision="revision",
                repo_type="model",
                token="token",
            )

View File

@ -18,7 +18,7 @@ from vllm.logprobs import PromptLogprobs, SampleLogprobs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from vllm.v1.engine import ( from vllm.v1.engine import (
EngineCoreEvent, EngineCoreEvent,
EngineCoreEventType, EngineCoreEventType,
@ -31,7 +31,7 @@ from vllm.v1.metrics.stats import IterationStats, SchedulerStats
def _ref_convert_id_to_token( def _ref_convert_id_to_token(
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
token_id: int, token_id: int,
) -> str: ) -> str:
"""Reference impl of logprobs detokenization. """Reference impl of logprobs detokenization.

View File

@ -27,8 +27,8 @@ ALLOWED_FILES = {
"vllm/distributed/device_communicators/shm_broadcast.py", "vllm/distributed/device_communicators/shm_broadcast.py",
"vllm/distributed/device_communicators/shm_object_storage.py", "vllm/distributed/device_communicators/shm_object_storage.py",
"vllm/utils/hashing.py", "vllm/utils/hashing.py",
"tests/tokenizers_/test_cached_tokenizer.py",
"tests/utils_/test_hashing.py", "tests/utils_/test_hashing.py",
"tests/tokenization/test_cached_tokenizer.py",
"benchmarks/kernels/graph_machete_bench.py", "benchmarks/kernels/graph_machete_bench.py",
"benchmarks/kernels/benchmark_lora.py", "benchmarks/kernels/benchmark_lora.py",
"benchmarks/kernels/benchmark_machete.py", "benchmarks/kernels/benchmark_machete.py",

View File

@ -35,6 +35,7 @@ FILES = [
"vllm/multimodal", "vllm/multimodal",
"vllm/platforms", "vllm/platforms",
"vllm/plugins", "vllm/plugins",
"vllm/tokenizers",
"vllm/transformers_utils", "vllm/transformers_utils",
"vllm/triton_utils", "vllm/triton_utils",
"vllm/usage", "vllm/usage",

View File

@ -39,7 +39,7 @@ from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path from vllm.lora.utils import get_adapter_absolute_path
from vllm.multimodal import MultiModalDataDict from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.image import convert_image_mode from vllm.multimodal.image import convert_image_mode
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from vllm.utils.import_utils import PlaceholderModule from vllm.utils.import_utils import PlaceholderModule
try: try:
@ -293,7 +293,7 @@ def lora_path_on_disk(lora_path: str) -> str:
# Global cache for LoRA tokenizers. # Global cache for LoRA tokenizers.
lora_tokenizer_cache: dict[int, AnyTokenizer] = {} lora_tokenizer_cache: dict[int, TokenizerLike] = {}
def process_image(image: Any) -> Mapping[str, Any]: def process_image(image: Any) -> Mapping[str, Any]:

View File

@ -13,7 +13,7 @@ from vllm.plugins.io_processors import IOProcessor
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask from vllm.tasks import SupportedTask
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.input_processor import InputProcessor from vllm.v1.engine.input_processor import InputProcessor
@ -85,7 +85,7 @@ class EngineClient(ABC):
... ...
@abstractmethod @abstractmethod
async def get_tokenizer(self) -> AnyTokenizer: async def get_tokenizer(self) -> TokenizerLike:
"""Get the tokenizer""" """Get the tokenizer"""
... ...

View File

@ -49,9 +49,9 @@ from vllm.logger import init_logger
from vllm.model_executor.models import SupportsMultiModal from vllm.model_executor.models import SupportsMultiModal
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict
from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector
from vllm.tokenizers import MistralTokenizer, TokenizerLike
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import random_uuid from vllm.utils import random_uuid
from vllm.utils.func_utils import supports_kw from vllm.utils.func_utils import supports_kw
@ -536,7 +536,7 @@ def resolve_hf_chat_template(
def _resolve_chat_template_content_format( def _resolve_chat_template_content_format(
chat_template: str | None, chat_template: str | None,
tools: list[dict[str, Any]] | None, tools: list[dict[str, Any]] | None,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
*, *,
model_config: ModelConfig, model_config: ModelConfig,
) -> _ChatTemplateContentFormat: ) -> _ChatTemplateContentFormat:
@ -593,7 +593,7 @@ def resolve_chat_template_content_format(
chat_template: str | None, chat_template: str | None,
tools: list[dict[str, Any]] | None, tools: list[dict[str, Any]] | None,
given_format: ChatTemplateContentFormatOption, given_format: ChatTemplateContentFormatOption,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
*, *,
model_config: ModelConfig, model_config: ModelConfig,
) -> _ChatTemplateContentFormat: ) -> _ChatTemplateContentFormat:
@ -627,7 +627,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
maximum per prompt. maximum per prompt.
""" """
def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer): def __init__(self, model_config: ModelConfig, tokenizer: TokenizerLike):
super().__init__() super().__init__()
self._model_config = model_config self._model_config = model_config
@ -1592,7 +1592,7 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None:
def parse_chat_messages( def parse_chat_messages(
messages: list[ChatCompletionMessageParam], messages: list[ChatCompletionMessageParam],
model_config: ModelConfig, model_config: ModelConfig,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
content_format: _ChatTemplateContentFormat, content_format: _ChatTemplateContentFormat,
) -> tuple[ ) -> tuple[
list[ConversationMessage], list[ConversationMessage],
@ -1624,7 +1624,7 @@ def parse_chat_messages(
def parse_chat_messages_futures( def parse_chat_messages_futures(
messages: list[ChatCompletionMessageParam], messages: list[ChatCompletionMessageParam],
model_config: ModelConfig, model_config: ModelConfig,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
content_format: _ChatTemplateContentFormat, content_format: _ChatTemplateContentFormat,
) -> tuple[ ) -> tuple[
list[ConversationMessage], list[ConversationMessage],

View File

@ -71,11 +71,8 @@ from vllm.platforms import current_platform
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
from vllm.tasks import PoolingTask from vllm.tasks import PoolingTask
from vllm.transformers_utils.tokenizer import ( from vllm.tokenizers import MistralTokenizer, TokenizerLike
AnyTokenizer, from vllm.transformers_utils.tokenizer import get_cached_tokenizer
MistralTokenizer,
get_cached_tokenizer,
)
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils.collection_utils import as_iter, is_list_of from vllm.utils.collection_utils import as_iter, is_list_of
from vllm.utils.counter import Counter from vllm.utils.counter import Counter
@ -350,11 +347,11 @@ class LLM:
self.input_processor = self.llm_engine.input_processor self.input_processor = self.llm_engine.input_processor
self.io_processor = self.llm_engine.io_processor self.io_processor = self.llm_engine.io_processor
def get_tokenizer(self) -> AnyTokenizer: def get_tokenizer(self) -> TokenizerLike:
return self.llm_engine.get_tokenizer() return self.llm_engine.get_tokenizer()
@deprecated("`set_tokenizer` is deprecated and will be removed in v0.13.") @deprecated("`set_tokenizer` is deprecated and will be removed in v0.13.")
def set_tokenizer(self, tokenizer: AnyTokenizer) -> None: def set_tokenizer(self, tokenizer: TokenizerLike) -> None:
# While CachedTokenizer is dynamic, have no choice but # While CachedTokenizer is dynamic, have no choice but
# compare class name. Misjudgment will arise from # compare class name. Misjudgment will arise from
# user-defined tokenizer started with 'Cached' # user-defined tokenizer started with 'Cached'
@ -1244,7 +1241,7 @@ class LLM:
def _embedding_score( def _embedding_score(
self, self,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
text_1: list[str | TextPrompt | TokensPrompt], text_1: list[str | TextPrompt | TokensPrompt],
text_2: list[str | TextPrompt | TokensPrompt], text_2: list[str | TextPrompt | TokensPrompt],
truncate_prompt_tokens: int | None = None, truncate_prompt_tokens: int | None = None,
@ -1276,7 +1273,7 @@ class LLM:
def _cross_encoding_score( def _cross_encoding_score(
self, self,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
data_1: list[str] | list[ScoreContentPartParam], data_1: list[str] | list[ScoreContentPartParam],
data_2: list[str] | list[ScoreContentPartParam], data_2: list[str] | list[ScoreContentPartParam],
truncate_prompt_tokens: int | None = None, truncate_prompt_tokens: int | None = None,

View File

@ -62,8 +62,9 @@ from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.tokenizers import ( from vllm.tokenizers.mistral import (
MistralTokenizer,
maybe_serialize_tool_calls, maybe_serialize_tool_calls,
truncate_tool_call_ids, truncate_tool_call_ids,
validate_request_params, validate_request_params,
@ -530,7 +531,7 @@ class OpenAIServingChat(OpenAIServing):
request_id: str, request_id: str,
model_name: str, model_name: str,
conversation: list[ConversationMessage], conversation: list[ConversationMessage],
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata, request_metadata: RequestResponseMetadata,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
created_time = int(time.time()) created_time = int(time.time())
@ -1296,7 +1297,7 @@ class OpenAIServingChat(OpenAIServing):
request_id: str, request_id: str,
model_name: str, model_name: str,
conversation: list[ConversationMessage], conversation: list[ConversationMessage],
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata, request_metadata: RequestResponseMetadata,
) -> ErrorResponse | ChatCompletionResponse: ) -> ErrorResponse | ChatCompletionResponse:
created_time = int(time.time()) created_time = int(time.time())
@ -1624,7 +1625,7 @@ class OpenAIServingChat(OpenAIServing):
self, self,
logprobs: dict[int, Logprob], logprobs: dict[int, Logprob],
top_logprobs: int | None, top_logprobs: int | None,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
should_return_as_token_id: bool, should_return_as_token_id: bool,
) -> list[ChatCompletionLogProb]: ) -> list[ChatCompletionLogProb]:
return [ return [
@ -1648,7 +1649,7 @@ class OpenAIServingChat(OpenAIServing):
self, self,
token_ids: GenericSequence[int], token_ids: GenericSequence[int],
top_logprobs: GenericSequence[dict[int, Logprob] | None], top_logprobs: GenericSequence[dict[int, Logprob] | None],
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
num_output_top_logprobs: int | None = None, num_output_top_logprobs: int | None = None,
return_as_token_id: bool | None = None, return_as_token_id: bool | None = None,
) -> ChatCompletionLogProbs: ) -> ChatCompletionLogProbs:

View File

@ -221,7 +221,7 @@ class ServingClassification(ClassificationMixin):
def _create_pooling_params( def _create_pooling_params(
self, self,
ctx: ClassificationServeContext, ctx: ServeContext[ClassificationRequest],
) -> PoolingParams | ErrorResponse: ) -> PoolingParams | ErrorResponse:
pooling_params = super()._create_pooling_params(ctx) pooling_params = super()._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse): if isinstance(pooling_params, ErrorResponse):

View File

@ -33,7 +33,7 @@ from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import merge_async_iterators from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import as_list from vllm.utils.collection_utils import as_list
from vllm.v1.sample.logits_processor import validate_logits_processors_parameters from vllm.v1.sample.logits_processor import validate_logits_processors_parameters
@ -326,7 +326,7 @@ class OpenAIServingCompletion(OpenAIServing):
created_time: int, created_time: int,
model_name: str, model_name: str,
num_prompts: int, num_prompts: int,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike | None,
request_metadata: RequestResponseMetadata, request_metadata: RequestResponseMetadata,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
num_choices = 1 if request.n is None else request.n num_choices = 1 if request.n is None else request.n
@ -511,7 +511,7 @@ class OpenAIServingCompletion(OpenAIServing):
request_id: str, request_id: str,
created_time: int, created_time: int,
model_name: str, model_name: str,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike | None,
request_metadata: RequestResponseMetadata, request_metadata: RequestResponseMetadata,
) -> CompletionResponse: ) -> CompletionResponse:
choices: list[CompletionResponseChoice] = [] choices: list[CompletionResponseChoice] = []
@ -622,7 +622,7 @@ class OpenAIServingCompletion(OpenAIServing):
token_ids: GenericSequence[int], token_ids: GenericSequence[int],
top_logprobs: GenericSequence[dict[int, Logprob] | None], top_logprobs: GenericSequence[dict[int, Logprob] | None],
num_output_top_logprobs: int, num_output_top_logprobs: int,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike | None,
initial_text_offset: int = 0, initial_text_offset: int = 0,
return_as_token_id: bool | None = None, return_as_token_id: bool | None = None,
) -> CompletionLogProbs: ) -> CompletionLogProbs:
@ -642,9 +642,15 @@ class OpenAIServingCompletion(OpenAIServing):
for i, token_id in enumerate(token_ids): for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i] step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None: if step_top_logprobs is None:
token = tokenizer.decode(token_id)
if should_return_as_token_id: if should_return_as_token_id:
token = f"token_id:{token_id}" token = f"token_id:{token_id}"
else:
if tokenizer is None:
raise ValueError(
"Unable to get tokenizer because `skip_tokenizer_init=True`"
)
token = tokenizer.decode(token_id)
out_tokens.append(token) out_tokens.append(token)
out_token_logprobs.append(None) out_token_logprobs.append(None)

View File

@ -7,13 +7,14 @@ import time
import traceback import traceback
from collections.abc import AsyncGenerator, Callable, Iterable, Mapping, Sequence from collections.abc import AsyncGenerator, Callable, Iterable, Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from http import HTTPStatus from http import HTTPStatus
from typing import Any, ClassVar, Generic, TypeAlias, TypeVar from typing import Any, ClassVar, Generic, TypeAlias, TypeVar
import numpy as np import numpy as np
import torch import torch
from fastapi import Request from fastapi import Request
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter from pydantic import ConfigDict, TypeAdapter
from starlette.datastructures import Headers from starlette.datastructures import Headers
from typing_extensions import TypeIs from typing_extensions import TypeIs
@ -96,12 +97,12 @@ from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.reasoning import ReasoningParser, ReasoningParserManager from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import MistralTokenizer, TokenizerLike
from vllm.tracing import ( from vllm.tracing import (
contains_trace_headers, contains_trace_headers,
extract_trace_headers, extract_trace_headers,
log_tracing_disabled_warning, log_tracing_disabled_warning,
) )
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import random_uuid from vllm.utils import random_uuid
from vllm.utils.async_utils import ( from vllm.utils.async_utils import (
AsyncMicrobatchTokenizer, AsyncMicrobatchTokenizer,
@ -184,19 +185,19 @@ def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
RequestT = TypeVar("RequestT", bound=AnyRequest) RequestT = TypeVar("RequestT", bound=AnyRequest)
class RequestProcessingMixin(BaseModel): @dataclass(kw_only=True)
class RequestProcessingMixin:
""" """
Mixin for request processing, Mixin for request processing,
handling prompt preparation and engine input. handling prompt preparation and engine input.
""" """
request_prompts: Sequence[RequestPrompt] | None = [] request_prompts: Sequence[RequestPrompt] | None = field(default_factory=list)
engine_prompts: list[EngineTokensPrompt] | None = [] engine_prompts: list[EngineTokensPrompt] | None = field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True)
class ResponseGenerationMixin(BaseModel): @dataclass(kw_only=True)
class ResponseGenerationMixin:
""" """
Mixin for response generation, Mixin for response generation,
managing result generators and final batch results. managing result generators and final batch results.
@ -205,54 +206,38 @@ class ResponseGenerationMixin(BaseModel):
result_generator: ( result_generator: (
AsyncGenerator[tuple[int, RequestOutput | PoolingRequestOutput], None] | None AsyncGenerator[tuple[int, RequestOutput | PoolingRequestOutput], None] | None
) = None ) = None
final_res_batch: list[RequestOutput | PoolingRequestOutput] = Field( final_res_batch: list[RequestOutput | PoolingRequestOutput] = field(
default_factory=list default_factory=list
) )
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
class ServeContext( @dataclass(kw_only=True)
RequestProcessingMixin, class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, Generic[RequestT]):
ResponseGenerationMixin,
BaseModel,
Generic[RequestT],
):
# Shared across all requests # Shared across all requests
request: RequestT request: RequestT
raw_request: Request | None = None raw_request: Request | None = None
model_name: str model_name: str
request_id: str request_id: str
created_time: int = Field(default_factory=lambda: int(time.time())) created_time: int = field(default_factory=lambda: int(time.time()))
lora_request: LoRARequest | None = None lora_request: LoRARequest | None = None
# Shared across most requests # Shared across most requests
tokenizer: AnyTokenizer | None = None tokenizer: TokenizerLike | None = None
# `protected_namespaces` resolves Pydantic v2's warning
# on conflict with protected namespace "model_"
model_config = ConfigDict(
protected_namespaces=(),
arbitrary_types_allowed=True,
)
ClassificationServeContext = ServeContext[ClassificationRequest] @dataclass(kw_only=True)
class ClassificationServeContext(ServeContext[ClassificationRequest]):
pass
@dataclass(kw_only=True)
class EmbeddingServeContext(ServeContext[EmbeddingRequest]): class EmbeddingServeContext(ServeContext[EmbeddingRequest]):
chat_template: str | None = None chat_template: str | None = None
chat_template_content_format: ChatTemplateContentFormatOption chat_template_content_format: ChatTemplateContentFormatOption
# Used to resolve the Pydantic error related to
# forward reference of MultiModalDataDict in TokensPrompt
RequestProcessingMixin.model_rebuild()
ServeContext.model_rebuild()
ClassificationServeContext.model_rebuild()
EmbeddingServeContext.model_rebuild()
class OpenAIServing: class OpenAIServing:
request_id_prefix: ClassVar[str] = """ request_id_prefix: ClassVar[str] = """
A short string prepended to every requests ID (e.g. "embd", "classify") A short string prepended to every requests ID (e.g. "embd", "classify")
@ -281,7 +266,7 @@ class OpenAIServing:
apply_mistral_chat_template, executor=self._tokenizer_executor apply_mistral_chat_template, executor=self._tokenizer_executor
) )
self._async_tokenizer_pool: dict[AnyTokenizer, AsyncMicrobatchTokenizer] = {} self._async_tokenizer_pool: dict[TokenizerLike, AsyncMicrobatchTokenizer] = {}
self.log_error_stack = log_error_stack self.log_error_stack = log_error_stack
self.input_processor = self.models.input_processor self.input_processor = self.models.input_processor
@ -291,7 +276,7 @@ class OpenAIServing:
def _get_tool_parser( def _get_tool_parser(
self, tool_parser_name: str | None = None, enable_auto_tools: bool = False self, tool_parser_name: str | None = None, enable_auto_tools: bool = False
) -> Callable[[AnyTokenizer], ToolParser] | None: ) -> Callable[[TokenizerLike], ToolParser] | None:
"""Get the tool parser based on the name.""" """Get the tool parser based on the name."""
parser = None parser = None
if not enable_auto_tools or tool_parser_name is None: if not enable_auto_tools or tool_parser_name is None:
@ -317,7 +302,7 @@ class OpenAIServing:
def _get_reasoning_parser( def _get_reasoning_parser(
self, self,
reasoning_parser_name: str, reasoning_parser_name: str,
) -> Callable[[AnyTokenizer], ReasoningParser] | None: ) -> Callable[[TokenizerLike], ReasoningParser] | None:
"""Get the reasoning parser based on the name.""" """Get the reasoning parser based on the name."""
parser = None parser = None
if not reasoning_parser_name: if not reasoning_parser_name:
@ -547,7 +532,7 @@ class OpenAIServing:
prompt_logprobs=None, prompt_logprobs=None,
) )
def _get_renderer(self, tokenizer: AnyTokenizer | None) -> BaseRenderer: def _get_renderer(self, tokenizer: TokenizerLike | None) -> BaseRenderer:
""" """
Get a Renderer instance with the provided tokenizer. Get a Renderer instance with the provided tokenizer.
Uses shared async tokenizer pool for efficiency. Uses shared async tokenizer pool for efficiency.
@ -877,7 +862,7 @@ class OpenAIServing:
self, self,
request: AnyRequest, request: AnyRequest,
prompt: str, prompt: str,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
add_special_tokens: bool, add_special_tokens: bool,
) -> TextTokensPrompt: ) -> TextTokensPrompt:
async_tokenizer = self._get_async_tokenizer(tokenizer) async_tokenizer = self._get_async_tokenizer(tokenizer)
@ -919,7 +904,7 @@ class OpenAIServing:
self, self,
request: AnyRequest, request: AnyRequest,
prompt_ids: list[int], prompt_ids: list[int],
tokenizer: AnyTokenizer | None, tokenizer: TokenizerLike | None,
) -> TextTokensPrompt: ) -> TextTokensPrompt:
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None) truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
@ -1015,7 +1000,7 @@ class OpenAIServing:
async def _tokenize_prompt_input_async( async def _tokenize_prompt_input_async(
self, self,
request: AnyRequest, request: AnyRequest,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
prompt_input: str | list[int], prompt_input: str | list[int],
add_special_tokens: bool = True, add_special_tokens: bool = True,
) -> TextTokensPrompt: ) -> TextTokensPrompt:
@ -1034,7 +1019,7 @@ class OpenAIServing:
async def _tokenize_prompt_inputs_async( async def _tokenize_prompt_inputs_async(
self, self,
request: AnyRequest, request: AnyRequest,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
prompt_inputs: Iterable[str | list[int]], prompt_inputs: Iterable[str | list[int]],
add_special_tokens: bool = True, add_special_tokens: bool = True,
) -> AsyncGenerator[TextTokensPrompt, None]: ) -> AsyncGenerator[TextTokensPrompt, None]:
@ -1079,7 +1064,7 @@ class OpenAIServing:
async def _preprocess_chat( async def _preprocess_chat(
self, self,
request: ChatLikeRequest | ResponsesRequest, request: ChatLikeRequest | ResponsesRequest,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike | None,
messages: list[ChatCompletionMessageParam], messages: list[ChatCompletionMessageParam],
chat_template: str | None, chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption, chat_template_content_format: ChatTemplateContentFormatOption,
@ -1088,13 +1073,18 @@ class OpenAIServing:
tool_dicts: list[dict[str, Any]] | None = None, tool_dicts: list[dict[str, Any]] | None = None,
documents: list[dict[str, str]] | None = None, documents: list[dict[str, str]] | None = None,
chat_template_kwargs: dict[str, Any] | None = None, chat_template_kwargs: dict[str, Any] | None = None,
tool_parser: Callable[[AnyTokenizer], ToolParser] | None = None, tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
add_special_tokens: bool = False, add_special_tokens: bool = False,
) -> tuple[ ) -> tuple[
list[ConversationMessage], list[ConversationMessage],
Sequence[RequestPrompt], Sequence[RequestPrompt],
list[EngineTokensPrompt], list[EngineTokensPrompt],
]: ]:
if tokenizer is None:
raise ValueError(
"Unable to get tokenizer because `skip_tokenizer_init=True`"
)
model_config = self.model_config model_config = self.model_config
resolved_content_format = resolve_chat_template_content_format( resolved_content_format = resolve_chat_template_content_format(
@ -1370,9 +1360,9 @@ class OpenAIServing:
@staticmethod @staticmethod
def _parse_tool_calls_from_content( def _parse_tool_calls_from_content(
request: ResponsesRequest | ChatCompletionRequest, request: ResponsesRequest | ChatCompletionRequest,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
enable_auto_tools: bool, enable_auto_tools: bool,
tool_parser_cls: Callable[[AnyTokenizer], ToolParser] | None, tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
content: str | None = None, content: str | None = None,
) -> tuple[list[FunctionCall] | None, str | None]: ) -> tuple[list[FunctionCall] | None, str | None]:
function_calls = list[FunctionCall]() function_calls = list[FunctionCall]()
@ -1442,7 +1432,7 @@ class OpenAIServing:
def _get_decoded_token( def _get_decoded_token(
logprob: Logprob, logprob: Logprob,
token_id: int, token_id: int,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike | None,
return_as_token_id: bool = False, return_as_token_id: bool = False,
) -> str: ) -> str:
if return_as_token_id: if return_as_token_id:
@ -1450,6 +1440,12 @@ class OpenAIServing:
if logprob.decoded_token is not None: if logprob.decoded_token is not None:
return logprob.decoded_token return logprob.decoded_token
if tokenizer is None:
raise ValueError(
"Unable to get tokenizer because `skip_tokenizer_init=True`"
)
return tokenizer.decode(token_id) return tokenizer.decode(token_id)
def _is_model_supported(self, model_name: str | None) -> bool: def _is_model_supported(self, model_name: str | None) -> bool:

View File

@ -105,7 +105,7 @@ from vllm.logprobs import Logprob as SampleLogprob
from vllm.logprobs import SampleLogprobs from vllm.logprobs import SampleLogprobs
from vllm.outputs import CompletionOutput from vllm.outputs import CompletionOutput
from vllm.sampling_params import SamplingParams, StructuredOutputsParams from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from vllm.utils import random_uuid from vllm.utils import random_uuid
logger = init_logger(__name__) logger = init_logger(__name__)
@ -492,7 +492,7 @@ class OpenAIServingResponses(OpenAIServing):
self, self,
request: ResponsesRequest, request: ResponsesRequest,
prev_response: ResponsesResponse | None, prev_response: ResponsesResponse | None,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
): ):
if request.tools is None or ( if request.tools is None or (
request.tool_choice == "none" and self.exclude_tools_when_tool_choice_none request.tool_choice == "none" and self.exclude_tools_when_tool_choice_none
@ -563,7 +563,7 @@ class OpenAIServingResponses(OpenAIServing):
result_generator: AsyncIterator[ConversationContext], result_generator: AsyncIterator[ConversationContext],
context: ConversationContext, context: ConversationContext,
model_name: str, model_name: str,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata, request_metadata: RequestResponseMetadata,
created_time: int | None = None, created_time: int | None = None,
) -> ErrorResponse | ResponsesResponse: ) -> ErrorResponse | ResponsesResponse:
@ -675,7 +675,7 @@ class OpenAIServingResponses(OpenAIServing):
self, self,
logprobs: dict[int, SampleLogprob], logprobs: dict[int, SampleLogprob],
top_logprobs: int, top_logprobs: int,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
) -> list[LogprobTopLogprob]: ) -> list[LogprobTopLogprob]:
"""Returns the top-k logprobs from the logprobs dictionary.""" """Returns the top-k logprobs from the logprobs dictionary."""
out = [] out = []
@ -700,7 +700,7 @@ class OpenAIServingResponses(OpenAIServing):
self, self,
token_ids: Sequence[int], token_ids: Sequence[int],
logprobs: SampleLogprobs | None, logprobs: SampleLogprobs | None,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
top_logprobs: int | None = None, top_logprobs: int | None = None,
) -> list[Logprob]: ) -> list[Logprob]:
assert logprobs is not None, "logprobs must be provided" assert logprobs is not None, "logprobs must be provided"
@ -736,7 +736,7 @@ class OpenAIServingResponses(OpenAIServing):
self, self,
token_ids: Sequence[int], token_ids: Sequence[int],
logprobs: SampleLogprobs | None, logprobs: SampleLogprobs | None,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
top_logprobs: int | None = None, top_logprobs: int | None = None,
) -> list[response_text_delta_event.Logprob]: ) -> list[response_text_delta_event.Logprob]:
lgs = self._create_response_logprobs( lgs = self._create_response_logprobs(
@ -763,7 +763,7 @@ class OpenAIServingResponses(OpenAIServing):
self, self,
request: ResponsesRequest, request: ResponsesRequest,
final_output: CompletionOutput, final_output: CompletionOutput,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
) -> list[ResponseOutputItem]: ) -> list[ResponseOutputItem]:
if self.reasoning_parser: if self.reasoning_parser:
try: try:
@ -1135,7 +1135,7 @@ class OpenAIServingResponses(OpenAIServing):
result_generator: AsyncIterator[ConversationContext | None], result_generator: AsyncIterator[ConversationContext | None],
context: ConversationContext, context: ConversationContext,
model_name: str, model_name: str,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata, request_metadata: RequestResponseMetadata,
created_time: int, created_time: int,
_increment_sequence_number_and_return: Callable[ _increment_sequence_number_and_return: Callable[
@ -1438,7 +1438,7 @@ class OpenAIServingResponses(OpenAIServing):
result_generator: AsyncIterator[ConversationContext | None], result_generator: AsyncIterator[ConversationContext | None],
context: ConversationContext, context: ConversationContext,
model_name: str, model_name: str,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata, request_metadata: RequestResponseMetadata,
created_time: int, created_time: int,
_increment_sequence_number_and_return: Callable[ _increment_sequence_number_and_return: Callable[
@ -1891,7 +1891,7 @@ class OpenAIServingResponses(OpenAIServing):
result_generator: AsyncIterator[ConversationContext | None], result_generator: AsyncIterator[ConversationContext | None],
context: ConversationContext, context: ConversationContext,
model_name: str, model_name: str,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata, request_metadata: RequestResponseMetadata,
created_time: int | None = None, created_time: int | None = None,
) -> AsyncGenerator[StreamingResponsesResponse, None]: ) -> AsyncGenerator[StreamingResponsesResponse, None]:

View File

@ -36,7 +36,7 @@ from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.tokenizers import MistralTokenizer, TokenizerLike
from vllm.utils.async_utils import make_async, merge_async_iterators from vllm.utils.async_utils import make_async, merge_async_iterators
logger = init_logger(__name__) logger = init_logger(__name__)
@ -60,7 +60,7 @@ class ServingScores(OpenAIServing):
async def _embedding_score( async def _embedding_score(
self, self,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
texts_1: list[str], texts_1: list[str],
texts_2: list[str], texts_2: list[str],
request: RerankRequest | ScoreRequest, request: RerankRequest | ScoreRequest,
@ -153,7 +153,7 @@ class ServingScores(OpenAIServing):
def _preprocess_score( def _preprocess_score(
self, self,
request: RerankRequest | ScoreRequest, request: RerankRequest | ScoreRequest,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
tokenization_kwargs: dict[str, Any], tokenization_kwargs: dict[str, Any],
data_1: str | ScoreContentPartParam, data_1: str | ScoreContentPartParam,
data_2: str | ScoreContentPartParam, data_2: str | ScoreContentPartParam,
@ -175,7 +175,7 @@ class ServingScores(OpenAIServing):
async def _cross_encoding_score( async def _cross_encoding_score(
self, self,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
data_1: list[str] | list[ScoreContentPartParam], data_1: list[str] | list[ScoreContentPartParam],
data_2: list[str] | list[ScoreContentPartParam], data_2: list[str] | list[ScoreContentPartParam],
request: RerankRequest | ScoreRequest, request: RerankRequest | ScoreRequest,

View File

@ -22,7 +22,7 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig from vllm.entrypoints.renderer import RenderConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__) logger = init_logger(__name__)
@ -170,7 +170,7 @@ class OpenAIServingTokenization(OpenAIServing):
@dataclass @dataclass
class TokenizerInfo: class TokenizerInfo:
tokenizer: AnyTokenizer tokenizer: TokenizerLike
chat_template: str | None chat_template: str | None
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:

View File

@ -22,7 +22,7 @@ from vllm.logger import init_logger
from vllm.sampling_params import ( from vllm.sampling_params import (
StructuredOutputsParams, StructuredOutputsParams,
) )
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from vllm.utils.collection_utils import is_list_of from vllm.utils.collection_utils import is_list_of
from vllm.utils.import_utils import import_from_path from vllm.utils.import_utils import import_from_path
@ -36,7 +36,7 @@ class ToolParser:
derived classes. derived classes.
""" """
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
self.prev_tool_call_arr: list[dict] = [] self.prev_tool_call_arr: list[dict] = []
# the index of the tool call that is currently being parsed # the index of the tool call that is currently being parsed
self.current_tool_id: int = -1 self.current_tool_id: int = -1

View File

@ -19,13 +19,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParser,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__) logger = init_logger(__name__)
class DeepSeekV31ToolParser(ToolParser): class DeepSeekV31ToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
self.current_tool_name_sent: bool = False self.current_tool_name_sent: bool = False

View File

@ -19,13 +19,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParser,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__) logger = init_logger(__name__)
class DeepSeekV3ToolParser(ToolParser): class DeepSeekV3ToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
self.current_tool_name_sent: bool = False self.current_tool_name_sent: bool = False

View File

@ -19,13 +19,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParser,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__) logger = init_logger(__name__)
class Ernie45ToolParser(ToolParser): class Ernie45ToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
""" """
Ernie thinking model format: Ernie thinking model format:
abc\n</think>\n\n\n<tool_call>\ndef\n</tool_call>\n abc\n</think>\n\n\n<tool_call>\ndef\n</tool_call>\n

View File

@ -22,13 +22,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParser,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__) logger = init_logger(__name__)
class Glm4MoeModelToolParser(ToolParser): class Glm4MoeModelToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
self.current_tool_name_sent = False self.current_tool_name_sent = False
self.prev_tool_call_arr: list[dict] = [] self.prev_tool_call_arr: list[dict] = []

View File

@ -29,7 +29,7 @@ from vllm.entrypoints.openai.tool_parsers.utils import (
partial_json_loads, partial_json_loads,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__) logger = init_logger(__name__)
@ -44,7 +44,7 @@ class Granite20bFCToolParser(ToolParser):
are all set are all set
""" """
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
self.bot_token = "<function_call>" self.bot_token = "<function_call>"

View File

@ -27,7 +27,7 @@ from vllm.entrypoints.openai.tool_parsers.utils import (
partial_json_loads, partial_json_loads,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__) logger = init_logger(__name__)
@ -42,7 +42,7 @@ class GraniteToolParser(ToolParser):
are all set are all set
""" """
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
# for granite 3.0, the token `<|tool_call|>` # for granite 3.0, the token `<|tool_call|>`
self.bot_token = "<|tool_call|>" self.bot_token = "<|tool_call|>"

View File

@ -22,18 +22,18 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParser,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.tokenizers import MistralTokenizer, TokenizerLike
logger = init_logger(__name__) logger = init_logger(__name__)
class Hermes2ProToolParser(ToolParser): class Hermes2ProToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
if isinstance(self.model_tokenizer, MistralTokenizer): if isinstance(tokenizer, MistralTokenizer):
logger.error("Detected Mistral tokenizer when using a Hermes model") logger.error("Detected Mistral tokenizer when using a Hermes model")
self.model_tokenizer = self.model_tokenizer.tokenizer self.model_tokenizer = tokenizer.tokenizer
self.current_tool_name_sent: bool = False self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = [] self.prev_tool_call_arr: list[dict] = []

View File

@ -22,14 +22,14 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
) )
from vllm.entrypoints.openai.tool_parsers.utils import consume_space from vllm.entrypoints.openai.tool_parsers.utils import consume_space
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from vllm.utils import random_uuid from vllm.utils import random_uuid
logger = init_logger(__name__) logger = init_logger(__name__)
class HunyuanA13BToolParser(ToolParser): class HunyuanA13BToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
# Initialize state for streaming mode # Initialize state for streaming mode

View File

@ -22,13 +22,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
) )
from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__) logger = init_logger(__name__)
class Internlm2ToolParser(ToolParser): class Internlm2ToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
self.position = 0 self.position = 0

View File

@ -21,14 +21,13 @@ from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.tool_parsers import ToolParser from vllm.entrypoints.openai.tool_parsers import ToolParser
from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import MistralTokenizer, TokenizerLike
from vllm.transformers_utils.tokenizers import MistralTokenizer
logger = init_logger(__name__) logger = init_logger(__name__)
class JambaToolParser(ToolParser): class JambaToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
if isinstance(self.model_tokenizer, MistralTokenizer): if isinstance(self.model_tokenizer, MistralTokenizer):

View File

@ -19,13 +19,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParser,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__) logger = init_logger(__name__)
class KimiK2ToolParser(ToolParser): class KimiK2ToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
self.current_tool_name_sent: bool = False self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = [] self.prev_tool_call_arr: list[dict] = []

View File

@ -4,11 +4,11 @@
import regex as re import regex as re
from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
class LongcatFlashToolParser(Hermes2ProToolParser): class LongcatFlashToolParser(Hermes2ProToolParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
self.tool_call_start_token: str = "<longcat_tool_call>" self.tool_call_start_token: str = "<longcat_tool_call>"

View File

@ -21,13 +21,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParser,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__) logger = init_logger(__name__)
class MinimaxM2ToolParser(ToolParser): class MinimaxM2ToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
self.prev_tool_call_arr: list[dict] = [] self.prev_tool_call_arr: list[dict] = []

View File

@ -22,13 +22,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
) )
from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__) logger = init_logger(__name__)
class MinimaxToolParser(ToolParser): class MinimaxToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
# Initialize streaming state for tracking tool call progress # Initialize streaming state for tracking tool call progress

View File

@ -25,7 +25,7 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
) )
from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.tokenizers import MistralTokenizer, TokenizerLike
logger = init_logger(__name__) logger = init_logger(__name__)
@ -46,7 +46,7 @@ class MistralToolCall(ToolCall):
return id.isalnum() and len(id) == 9 return id.isalnum() and len(id) == 9
def _is_fn_name_regex_support(model_tokenizer: AnyTokenizer) -> bool: def _is_fn_name_regex_support(model_tokenizer: TokenizerLike) -> bool:
return ( return (
isinstance(model_tokenizer, MistralTokenizer) and model_tokenizer.version >= 11 isinstance(model_tokenizer, MistralTokenizer) and model_tokenizer.version >= 11
) )
@ -61,7 +61,7 @@ class MistralToolParser(ToolParser):
Used when --enable-auto-tool-choice --tool-call-parser mistral are all set Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
""" """
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
if not isinstance(self.model_tokenizer, MistralTokenizer): if not isinstance(self.model_tokenizer, MistralTokenizer):

View File

@ -18,15 +18,15 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
from vllm.logger import init_logger from vllm.logger import init_logger
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
else: else:
AnyTokenizer = object TokenizerLike = object
logger = init_logger(__name__) logger = init_logger(__name__)
class OpenAIToolParser(ToolParser): class OpenAIToolParser(ToolParser):
def __init__(self, tokenizer: "AnyTokenizer"): def __init__(self, tokenizer: "TokenizerLike"):
super().__init__(tokenizer) super().__init__(tokenizer)
def extract_tool_calls( def extract_tool_calls(

View File

@ -22,13 +22,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParser,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__) logger = init_logger(__name__)
class Qwen3CoderToolParser(ToolParser): class Qwen3CoderToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
self.current_tool_name_sent: bool = False self.current_tool_name_sent: bool = False

View File

@ -23,7 +23,7 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParser,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__) logger = init_logger(__name__)
@ -1165,7 +1165,7 @@ class StreamingXMLToolCallParser:
class Qwen3XMLToolParser(ToolParser): class Qwen3XMLToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
self.parser = StreamingXMLToolCallParser() self.parser = StreamingXMLToolCallParser()

View File

@ -25,7 +25,7 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParser,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__) logger = init_logger(__name__)
@ -34,7 +34,7 @@ class SeedOssToolParser(ToolParser):
TOOL_CALL_START = "<seed:tool_call>" TOOL_CALL_START = "<seed:tool_call>"
TOOL_CALL_END = "</seed:tool_call>" TOOL_CALL_END = "</seed:tool_call>"
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
# --- streaming state --- # --- streaming state ---

View File

@ -21,7 +21,7 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParser,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from vllm.utils import random_uuid from vllm.utils import random_uuid
logger = init_logger(__name__) logger = init_logger(__name__)
@ -41,7 +41,7 @@ class Step3ToolParser(ToolParser):
TOOL_SEP = "<tool_sep>" TOOL_SEP = "<tool_sep>"
SPECIAL_TOKENS = [TOOL_CALLS_BEGIN, TOOL_CALLS_END, TOOL_CALL_BEGIN, TOOL_CALL_END] SPECIAL_TOKENS = [TOOL_CALLS_BEGIN, TOOL_CALLS_END, TOOL_CALL_BEGIN, TOOL_CALL_END]
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
self.position = 0 self.position = 0
# Explicit state flags for robust streaming # Explicit state flags for robust streaming

View File

@ -21,14 +21,14 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParser,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from vllm.utils import random_uuid from vllm.utils import random_uuid
logger = init_logger(__name__) logger = init_logger(__name__)
class xLAMToolParser(ToolParser): class xLAMToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
# Initialize state for streaming mode # Initialize state for streaming mode

View File

@ -16,7 +16,7 @@ from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
from vllm.inputs.data import TextPrompt as EngineTextPrompt from vllm.inputs.data import TextPrompt as EngineTextPrompt
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.inputs.parse import get_prompt_components, parse_raw_prompts from vllm.inputs.parse import get_prompt_components, parse_raw_prompts
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import AsyncMicrobatchTokenizer from vllm.utils.async_utils import AsyncMicrobatchTokenizer
@ -85,7 +85,7 @@ class BaseRenderer(ABC):
def __init__( def __init__(
self, self,
model_config: ModelConfig, model_config: ModelConfig,
tokenizer: AnyTokenizer | None = None, tokenizer: TokenizerLike | None = None,
): ):
super().__init__() super().__init__()
self.model_config = model_config self.model_config = model_config
@ -200,8 +200,8 @@ class CompletionRenderer(BaseRenderer):
def __init__( def __init__(
self, self,
model_config: ModelConfig, model_config: ModelConfig,
tokenizer: AnyTokenizer | None = None, tokenizer: TokenizerLike | None = None,
async_tokenizer_pool: dict[AnyTokenizer, AsyncMicrobatchTokenizer] async_tokenizer_pool: dict[TokenizerLike, AsyncMicrobatchTokenizer]
| None = None, | None = None,
): ):
super().__init__(model_config, tokenizer) super().__init__(model_config, tokenizer)
@ -373,7 +373,7 @@ class CompletionRenderer(BaseRenderer):
return async_tokenizer return async_tokenizer
tokenizer = self.tokenizer tokenizer = self.tokenizer
if self.tokenizer is None: if tokenizer is None:
raise ValueError("No tokenizer available for text input processing") raise ValueError("No tokenizer available for text input processing")
if self.async_tokenizer_pool is None: if self.async_tokenizer_pool is None:

View File

@ -19,11 +19,7 @@ from vllm.inputs import TokensPrompt
from vllm.model_executor.models.interfaces import supports_score_template from vllm.model_executor.models.interfaces import supports_score_template
from vllm.multimodal.inputs import MultiModalDataDict from vllm.multimodal.inputs import MultiModalDataDict
from vllm.outputs import PoolingRequestOutput from vllm.outputs import PoolingRequestOutput
from vllm.transformers_utils.tokenizer import ( from vllm.transformers_utils.tokenizer import TokenizerLike
AnyTokenizer,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
)
ScoreContentPartParam: TypeAlias = ( ScoreContentPartParam: TypeAlias = (
ChatCompletionContentPartImageParam | ChatCompletionContentPartImageEmbedsParam ChatCompletionContentPartImageParam | ChatCompletionContentPartImageEmbedsParam
@ -45,7 +41,7 @@ class ScoreMultiModalParam(TypedDict, total=False):
def _cosine_similarity( def _cosine_similarity(
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, tokenizer: TokenizerLike,
embed_1: list[PoolingRequestOutput], embed_1: list[PoolingRequestOutput],
embed_2: list[PoolingRequestOutput], embed_2: list[PoolingRequestOutput],
) -> list[PoolingRequestOutput]: ) -> list[PoolingRequestOutput]:
@ -93,7 +89,7 @@ def parse_score_data(
data_1: str | ScoreContentPartParam, data_1: str | ScoreContentPartParam,
data_2: str | ScoreContentPartParam, data_2: str | ScoreContentPartParam,
model_config: ModelConfig, model_config: ModelConfig,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
) -> tuple[str, str, MultiModalDataDict | None]: ) -> tuple[str, str, MultiModalDataDict | None]:
mm_tracker = MultiModalItemTracker(model_config, tokenizer) mm_tracker = MultiModalItemTracker(model_config, tokenizer)
@ -118,12 +114,14 @@ def _parse_score_content(
mm_tracker: BaseMultiModalItemTracker, mm_tracker: BaseMultiModalItemTracker,
) -> _ContentPart | None: ) -> _ContentPart | None:
if isinstance(data, str): if isinstance(data, str):
data = ChatCompletionContentPartTextParam(type="text", text=data) part = ChatCompletionContentPartTextParam(type="text", text=data)
else:
part = data
mm_parser = mm_tracker.create_parser() mm_parser = mm_tracker.create_parser()
parse_res = _parse_chat_message_content_part( parse_res = _parse_chat_message_content_part(
data, part,
mm_parser, mm_parser,
wrap_dicts=False, wrap_dicts=False,
interleave_strings=False, interleave_strings=False,
@ -181,7 +179,7 @@ def post_process_tokens(
def get_score_prompt( def get_score_prompt(
model_config: ModelConfig, model_config: ModelConfig,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
tokenization_kwargs: dict[str, Any], tokenization_kwargs: dict[str, Any],
data_1: str | ScoreContentPartParam, data_1: str | ScoreContentPartParam,
data_2: str | ScoreContentPartParam, data_2: str | ScoreContentPartParam,

View File

@ -30,7 +30,7 @@ from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.serving_models import LoRAModulePath from vllm.entrypoints.openai.serving_models import LoRAModulePath
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.transformers_utils.tokenizers import MistralTokenizer from vllm.tokenizers import MistralTokenizer
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
logger = init_logger(__name__) logger = init_logger(__name__)

View File

@ -17,7 +17,7 @@ from vllm.multimodal.inputs import (
MultiModalUUIDDict, MultiModalUUIDDict,
) )
from vllm.multimodal.processing import BaseMultiModalProcessor from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from vllm.utils.jsontree import json_iter_leaves from vllm.utils.jsontree import json_iter_leaves
from vllm.v1.metrics.stats import MultiModalCacheStats from vllm.v1.metrics.stats import MultiModalCacheStats
@ -46,7 +46,7 @@ class InputPreprocessor:
def __init__( def __init__(
self, self,
model_config: ModelConfig, model_config: ModelConfig,
tokenizer: AnyTokenizer | None, tokenizer: TokenizerLike | None,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
mm_processor_cache: BaseMultiModalProcessorCache | None = None, mm_processor_cache: BaseMultiModalProcessorCache | None = None,
) -> None: ) -> None:
@ -59,7 +59,7 @@ class InputPreprocessor:
self.mm_cache_stats = MultiModalCacheStats() if mm_processor_cache else None self.mm_cache_stats = MultiModalCacheStats() if mm_processor_cache else None
def get_tokenizer(self) -> AnyTokenizer: def get_tokenizer(self) -> TokenizerLike:
if self.tokenizer is None: if self.tokenizer is None:
raise ValueError( raise ValueError(
"You cannot pass text prompts when `skip_tokenizer_init` is True" "You cannot pass text prompts when `skip_tokenizer_init` is True"
@ -228,11 +228,11 @@ class InputPreprocessor:
return tokenizer.encode(prompt, **tokenization_kwargs) return tokenizer.encode(prompt, **tokenization_kwargs)
def _get_mm_tokenizer(self) -> AnyTokenizer: def _get_mm_tokenizer(self) -> TokenizerLike:
# PrithviGeoSpatialMAE needs to be initialized without a tokenizer # PrithviGeoSpatialMAE needs to be initialized without a tokenizer
# while using also multi-modal input # while using also multi-modal input
if not self.tokenizer: if not self.tokenizer:
return cast(AnyTokenizer, object()) # Dummy return cast(TokenizerLike, object()) # Dummy
tokenizer = self.get_tokenizer() tokenizer = self.get_tokenizer()
return tokenizer return tokenizer

View File

@ -5,7 +5,7 @@ from typing import TypeAlias
import torch import torch
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
LogitsProcessor: TypeAlias = ( LogitsProcessor: TypeAlias = (
Callable[[list[int], torch.Tensor], torch.Tensor] Callable[[list[int], torch.Tensor], torch.Tensor]
@ -19,7 +19,7 @@ to sample from."""
def get_bad_words_logits_processors( def get_bad_words_logits_processors(
bad_words: list[str], tokenizer: AnyTokenizer bad_words: list[str], tokenizer: TokenizerLike
) -> list[LogitsProcessor]: ) -> list[LogitsProcessor]:
bad_words_ids: list[list[int]] = list() bad_words_ids: list[list[int]] = list()

View File

@ -28,7 +28,7 @@ from vllm.multimodal.processing import (
PromptUpdate, PromptUpdate,
PromptUpdateDetails, PromptUpdateDetails,
) )
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from .intern_vit import InternVisionModel from .intern_vit import InternVisionModel
from .internvl import ( from .internvl import (
@ -241,7 +241,7 @@ class H2OVLProcessor(BaseInternVLProcessor):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
*, *,
min_dynamic_patch: int | None = None, min_dynamic_patch: int | None = None,
max_dynamic_patch: int | None = None, max_dynamic_patch: int | None = None,

View File

@ -50,7 +50,7 @@ from vllm.multimodal.processing import (
) )
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from vllm.utils.torch_utils import set_default_torch_num_threads from vllm.utils.torch_utils import set_default_torch_num_threads
@ -347,7 +347,7 @@ class BaseInternVLProcessor(ABC):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
*, *,
min_dynamic_patch: int | None = None, min_dynamic_patch: int | None = None,
max_dynamic_patch: int | None = None, max_dynamic_patch: int | None = None,
@ -561,7 +561,7 @@ class InternVLProcessor(BaseInternVLProcessor):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
*, *,
min_dynamic_patch: int | None = None, min_dynamic_patch: int | None = None,
max_dynamic_patch: int | None = None, max_dynamic_patch: int | None = None,

View File

@ -73,9 +73,9 @@ from vllm.multimodal.processing import (
) )
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.configs.radio import RadioConfig from vllm.transformers_utils.configs.radio import RadioConfig
from vllm.transformers_utils.tokenizer import ( from vllm.transformers_utils.tokenizer import (
AnyTokenizer,
cached_tokenizer_from_config, cached_tokenizer_from_config,
encode_tokens, encode_tokens,
) )
@ -284,7 +284,7 @@ class BaseNanoNemotronVLProcessor(ABC):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
*args, *args,
max_num_tiles: int | None = None, max_num_tiles: int | None = None,
**kwargs, **kwargs,
@ -434,7 +434,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
*, *,
max_num_tiles: int | None = None, max_num_tiles: int | None = None,
min_dynamic_patch: int | None = None, min_dynamic_patch: int | None = None,
@ -645,7 +645,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
tokens_per_frame: list[int], tokens_per_frame: list[int],
frames_indices: list[int], frames_indices: list[int],
frame_duration_ms: int, frame_duration_ms: int,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
img_start_token_ids: list[int], img_start_token_ids: list[int],
img_end_token_ids: list[int], img_end_token_ids: list[int],
img_context_token_ids: list[int], img_context_token_ids: list[int],
@ -670,7 +670,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
tokens_per_frame (list[int]): number of tokens per frame tokens_per_frame (list[int]): number of tokens per frame
frames_indices (list[int]): frame indices frames_indices (list[int]): frame indices
frame_duration_ms (int): duration of each frame in milliseconds frame_duration_ms (int): duration of each frame in milliseconds
tokenizer (AnyTokenizer): tokenizer to use for tokenizing frame separators tokenizer (TokenizerLike): tokenizer to use for tokenizing frame separators
img_start_token_ids (list[int]): pre-tokenized IMG_START tokens img_start_token_ids (list[int]): pre-tokenized IMG_START tokens
img_end_token_ids (list[int]): pre-tokenized IMG_END tokens img_end_token_ids (list[int]): pre-tokenized IMG_END tokens
img_context_token_ids (list[int]): pre-tokenized IMG_CONTEXT tokens img_context_token_ids (list[int]): pre-tokenized IMG_CONTEXT tokens

View File

@ -34,8 +34,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import convert_image_mode from vllm.multimodal.image import convert_image_mode
from vllm.multimodal.processing import PromptUpdateDetails from vllm.multimodal.processing import PromptUpdateDetails
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.processor import cached_image_processor_from_config from vllm.transformers_utils.processor import cached_image_processor_from_config
from vllm.transformers_utils.tokenizer import AnyTokenizer
from .interfaces import ( from .interfaces import (
MultiModalEmbeddings, MultiModalEmbeddings,
@ -203,7 +203,7 @@ class NemotronVLProcessor(InternVLProcessor):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
image_processor: BaseImageProcessorFast, image_processor: BaseImageProcessorFast,
*, *,
min_dynamic_patch: int | None = None, min_dynamic_patch: int | None = None,

View File

@ -31,7 +31,7 @@ from vllm.multimodal.processing import (
PromptReplacement, PromptReplacement,
PromptUpdate, PromptUpdate,
) )
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from .qwen2_5_vl import ( from .qwen2_5_vl import (
Qwen2_5_VisionTransformer as OpenCUAVisionTransformer, Qwen2_5_VisionTransformer as OpenCUAVisionTransformer,
@ -79,7 +79,7 @@ class OpenCUAProcessor(Qwen2VLProcessor):
def __init__( def __init__(
self, self,
vision_config: dict, vision_config: dict,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
**kwargs, **kwargs,
): ):
image_processor = Qwen2VLImageProcessor(**vision_config) image_processor = Qwen2VLImageProcessor(**vision_config)

View File

@ -59,10 +59,8 @@ from vllm.multimodal.processing import (
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import ( from vllm.tokenizers import MistralTokenizer
MistralTokenizer, from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
cached_tokenizer_from_config,
)
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP

View File

@ -91,7 +91,7 @@ from vllm.multimodal.processing import (
) )
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import ( from .interfaces import (
@ -1533,7 +1533,7 @@ class Tarsier2Processor(Qwen2VLProcessor):
def __init__( def __init__(
self, self,
vision_config: dict, vision_config: dict,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
**kwargs, **kwargs,
): ):
self.image_processor = Tarsier2ImageProcessor(**vision_config) self.image_processor = Tarsier2ImageProcessor(**vision_config)

View File

@ -47,7 +47,7 @@ from vllm.multimodal.processing import (
) )
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
@ -282,7 +282,7 @@ class SkyworkR1VProcessor:
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
*, *,
min_dynamic_patch: int | None = None, min_dynamic_patch: int | None = None,
max_dynamic_patch: int | None = None, max_dynamic_patch: int | None = None,

View File

@ -43,8 +43,8 @@ from vllm.multimodal.processing import (
) )
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.configs import Step3VisionEncoderConfig from vllm.transformers_utils.configs import Step3VisionEncoderConfig
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
@ -321,7 +321,7 @@ class Step3VLProcessor:
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
) -> None: ) -> None:
super().__init__() super().__init__()

View File

@ -51,10 +51,8 @@ from vllm.multimodal.processing import (
) )
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import ( from vllm.tokenizers import MistralTokenizer
MistralTokenizer, from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
cached_tokenizer_from_config,
)
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription
from .utils import init_vllm_registered_model, maybe_prefix from .utils import init_vllm_registered_model, maybe_prefix

View File

@ -23,8 +23,9 @@ import torch
from typing_extensions import TypeVar, assert_never from typing_extensions import TypeVar, assert_never
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.processor import cached_processor_from_config from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.transformers_utils.tokenizer import AnyTokenizer, decode_tokens, encode_tokens from vllm.transformers_utils.tokenizer import decode_tokens, encode_tokens
from vllm.utils.collection_utils import flatten_2d_lists, full_groupby from vllm.utils.collection_utils import flatten_2d_lists, full_groupby
from vllm.utils.func_utils import get_allowed_kwarg_only_overrides from vllm.utils.func_utils import get_allowed_kwarg_only_overrides
from vllm.utils.jsontree import JSONTree, json_map_leaves from vllm.utils.jsontree import JSONTree, json_map_leaves
@ -76,7 +77,7 @@ PromptSeq: TypeAlias = str | list[int]
@lru_cache(maxsize=2048) @lru_cache(maxsize=2048)
def _cached_encode( def _cached_encode(
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
text: str, text: str,
*, *,
add_special_tokens: bool | None = None, add_special_tokens: bool | None = None,
@ -86,7 +87,7 @@ def _cached_encode(
@lru_cache(maxsize=2048) @lru_cache(maxsize=2048)
def _cached_decode( def _cached_decode(
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
token_ids: tuple[int, ...], token_ids: tuple[int, ...],
*, *,
skip_special_tokens: bool | None = None, skip_special_tokens: bool | None = None,
@ -96,14 +97,14 @@ def _cached_decode(
) )
def _seq2text(tokenizer: AnyTokenizer, seq: PromptSeq) -> str: def _seq2text(tokenizer: TokenizerLike, seq: PromptSeq) -> str:
if isinstance(seq, str): if isinstance(seq, str):
return seq return seq
return _cached_decode(tokenizer, tuple(seq)) return _cached_decode(tokenizer, tuple(seq))
def _seq2tokens(tokenizer: AnyTokenizer, seq: PromptSeq) -> list[int]: def _seq2tokens(tokenizer: TokenizerLike, seq: PromptSeq) -> list[int]:
if isinstance(seq, str): if isinstance(seq, str):
return _cached_encode(tokenizer, seq, add_special_tokens=False) return _cached_encode(tokenizer, seq, add_special_tokens=False)
@ -113,7 +114,7 @@ def _seq2tokens(tokenizer: AnyTokenizer, seq: PromptSeq) -> list[int]:
class _GetMatchIndex(Protocol): class _GetMatchIndex(Protocol):
def __call__( def __call__(
self, self,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
prompt: PromptSeq, prompt: PromptSeq,
start_idx: int = 0, start_idx: int = 0,
) -> int | None: ... ) -> int | None: ...
@ -143,7 +144,7 @@ class PromptIndexTargets:
""" """
def get_match_index( def get_match_index(
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
prompt: PromptSeq, prompt: PromptSeq,
start_idx: int = 0, start_idx: int = 0,
) -> int | None: ) -> int | None:
@ -199,7 +200,7 @@ class PromptUpdateDetails(Generic[_S]):
full: _S full: _S
"""The full content.""" """The full content."""
is_embed: Callable[[AnyTokenizer, PromptSeq], torch.Tensor] | None = None is_embed: Callable[[TokenizerLike, PromptSeq], torch.Tensor] | None = None
""" """
Given [`full`][vllm.multimodal.processing.PromptUpdateDetails.full], Given [`full`][vllm.multimodal.processing.PromptUpdateDetails.full],
return a boolean mask of shape `(len(full),)` indicating which positions return a boolean mask of shape `(len(full),)` indicating which positions
@ -220,7 +221,7 @@ class PromptUpdateDetails(Generic[_S]):
seq: _S, seq: _S,
embed_text: str, embed_text: str,
) -> "PromptUpdateDetails[_S]": ) -> "PromptUpdateDetails[_S]":
def is_embed(tokenizer: AnyTokenizer, full: PromptSeq) -> torch.Tensor: def is_embed(tokenizer: TokenizerLike, full: PromptSeq) -> torch.Tensor:
embed_token_ids = encode_tokens(tokenizer, embed_text) embed_token_ids = encode_tokens(tokenizer, embed_text)
token_ids = _seq2tokens(tokenizer, full) token_ids = _seq2tokens(tokenizer, full)
@ -236,7 +237,7 @@ class PromptUpdateDetails(Generic[_S]):
seq: _S, seq: _S,
embed_token_id: int, embed_token_id: int,
) -> "PromptUpdateDetails[_S]": ) -> "PromptUpdateDetails[_S]":
def is_embed(tokenizer: AnyTokenizer, full: PromptSeq) -> torch.Tensor: def is_embed(tokenizer: TokenizerLike, full: PromptSeq) -> torch.Tensor:
token_ids = _seq2tokens(tokenizer, full) token_ids = _seq2tokens(tokenizer, full)
return torch.tensor(token_ids) == embed_token_id return torch.tensor(token_ids) == embed_token_id
@ -522,7 +523,7 @@ class ResolvedPromptUpdate:
def iter_token_matches( def iter_token_matches(
self, self,
prompt: list[int], prompt: list[int],
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
*, *,
start_idx: int = 0, start_idx: int = 0,
) -> Generator[PromptTargetMatch]: ) -> Generator[PromptTargetMatch]:
@ -544,7 +545,7 @@ class ResolvedPromptUpdate:
def iter_text_matches( def iter_text_matches(
self, self,
prompt: str, prompt: str,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
*, *,
start_idx: int = 0, start_idx: int = 0,
) -> Generator[PromptTargetMatch]: ) -> Generator[PromptTargetMatch]:
@ -566,7 +567,7 @@ class ResolvedPromptUpdate:
def iter_matches( def iter_matches(
self, self,
prompt: list[int] | str, prompt: list[int] | str,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
*, *,
start_idx: int = 0, start_idx: int = 0,
) -> Generator[PromptTargetMatch]: ) -> Generator[PromptTargetMatch]:
@ -675,7 +676,7 @@ _MatchToApply = tuple[tuple[str, int], tuple[PromptTargetMatch, int]]
def _find_matches( def _find_matches(
prompt: _S, prompt: _S,
mm_prompt_updates: "MultiModalPromptUpdates", mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
*, *,
prev_end_idx: int = 0, prev_end_idx: int = 0,
current_result: "MultiModalPromptUpdatesApplyResult", current_result: "MultiModalPromptUpdatesApplyResult",
@ -740,7 +741,7 @@ def _all_items_found(
def _apply_matches( def _apply_matches(
prompt: _S, prompt: _S,
mm_prompt_updates: "MultiModalPromptUpdates", mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
) -> tuple[list[_S], "MultiModalPromptUpdatesApplyResult"]: ) -> tuple[list[_S], "MultiModalPromptUpdatesApplyResult"]:
mm_item_counts = {m: len(items) for m, items in mm_prompt_updates.items()} mm_item_counts = {m: len(items) for m, items in mm_prompt_updates.items()}
@ -806,7 +807,7 @@ def _apply_matches(
def apply_token_matches( def apply_token_matches(
prompt: list[int], prompt: list[int],
mm_prompt_updates: "MultiModalPromptUpdates", mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
) -> tuple[list[int], "MultiModalPromptUpdatesApplyResult"]: ) -> tuple[list[int], "MultiModalPromptUpdatesApplyResult"]:
""" """
Apply the updates in `mm_prompt_updates` to `prompt`. Apply the updates in `mm_prompt_updates` to `prompt`.
@ -823,7 +824,7 @@ def apply_token_matches(
def apply_text_matches( def apply_text_matches(
prompt: str, prompt: str,
mm_prompt_updates: "MultiModalPromptUpdates", mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
) -> tuple[str, "MultiModalPromptUpdatesApplyResult"]: ) -> tuple[str, "MultiModalPromptUpdatesApplyResult"]:
""" """
Apply the updates in `mm_prompt_updates` to `prompt`. Apply the updates in `mm_prompt_updates` to `prompt`.
@ -840,7 +841,7 @@ def apply_text_matches(
def _iter_placeholders( def _iter_placeholders(
prompt: list[int], prompt: list[int],
mm_prompt_updates: "MultiModalPromptUpdates", mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
) -> Iterable[PlaceholderFeaturesInfo]: ) -> Iterable[PlaceholderFeaturesInfo]:
""" """
Yield each set of placeholder tokens found in `prompt`. Yield each set of placeholder tokens found in `prompt`.
@ -909,7 +910,7 @@ def _iter_placeholders(
def find_mm_placeholders( def find_mm_placeholders(
prompt: list[int], prompt: list[int],
mm_prompt_updates: "MultiModalPromptUpdates", mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
) -> Mapping[str, list[PlaceholderFeaturesInfo]]: ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
it = _iter_placeholders(prompt, mm_prompt_updates, tokenizer) it = _iter_placeholders(prompt, mm_prompt_updates, tokenizer)
return dict(full_groupby_modality(it)) return dict(full_groupby_modality(it))
@ -930,7 +931,7 @@ class InputProcessingContext:
model_config: ModelConfig model_config: ModelConfig
"""The configuration of the model.""" """The configuration of the model."""
tokenizer: AnyTokenizer tokenizer: TokenizerLike
"""The tokenizer used to tokenize the inputs.""" """The tokenizer used to tokenize the inputs."""
@overload @overload
@ -1146,7 +1147,7 @@ class BaseProcessingInfo:
def model_id(self) -> str: def model_id(self) -> str:
return self.ctx.model_config.model return self.ctx.model_config.model
def get_tokenizer(self) -> AnyTokenizer: def get_tokenizer(self) -> TokenizerLike:
return self.ctx.tokenizer return self.ctx.tokenizer
def get_hf_config(self) -> PretrainedConfig: def get_hf_config(self) -> PretrainedConfig:

View File

@ -6,7 +6,8 @@ from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, cached_tokenizer_from_config from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from .cache import BaseMultiModalProcessorCache from .cache import BaseMultiModalProcessorCache
from .processing import ( from .processing import (
@ -231,17 +232,20 @@ class MultiModalRegistry:
def _create_processing_ctx( def _create_processing_ctx(
self, self,
model_config: "ModelConfig", model_config: "ModelConfig",
tokenizer: AnyTokenizer | None = None, tokenizer: TokenizerLike | None = None,
) -> InputProcessingContext: ) -> InputProcessingContext:
if tokenizer is None and not model_config.skip_tokenizer_init: if model_config.skip_tokenizer_init:
tokenizer = cast(TokenizerLike, object())
elif tokenizer is None:
tokenizer = cached_tokenizer_from_config(model_config) tokenizer = cached_tokenizer_from_config(model_config)
return InputProcessingContext(model_config, tokenizer) return InputProcessingContext(model_config, tokenizer)
def _create_processing_info( def _create_processing_info(
self, self,
model_config: "ModelConfig", model_config: "ModelConfig",
*, *,
tokenizer: AnyTokenizer | None = None, tokenizer: TokenizerLike | None = None,
) -> BaseProcessingInfo: ) -> BaseProcessingInfo:
model_cls = self._get_model_cls(model_config) model_cls = self._get_model_cls(model_config)
factories = model_cls._processor_factory factories = model_cls._processor_factory
@ -252,7 +256,7 @@ class MultiModalRegistry:
self, self,
model_config: "ModelConfig", model_config: "ModelConfig",
*, *,
tokenizer: AnyTokenizer | None = None, tokenizer: TokenizerLike | None = None,
cache: BaseMultiModalProcessorCache | None = None, cache: BaseMultiModalProcessorCache | None = None,
) -> BaseMultiModalProcessor[BaseProcessingInfo]: ) -> BaseMultiModalProcessor[BaseProcessingInfo]:
""" """

View File

@ -19,12 +19,12 @@ if TYPE_CHECKING:
DeltaMessage, DeltaMessage,
ResponsesRequest, ResponsesRequest,
) )
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
else: else:
ChatCompletionRequest = Any ChatCompletionRequest = Any
DeltaMessage = Any DeltaMessage = Any
ResponsesRequest = Any ResponsesRequest = Any
AnyTokenizer = Any TokenizerLike = Any
logger = init_logger(__name__) logger = init_logger(__name__)
@ -37,7 +37,7 @@ class ReasoningParser:
It is used to extract reasoning content from the model output. It is used to extract reasoning content from the model output.
""" """
def __init__(self, tokenizer: AnyTokenizer, *args, **kwargs): def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
self.model_tokenizer = tokenizer self.model_tokenizer = tokenizer
@cached_property @cached_property

View File

@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any
from vllm.entrypoints.openai.protocol import DeltaMessage from vllm.entrypoints.openai.protocol import DeltaMessage
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
@ -43,7 +43,7 @@ class BaseThinkingReasoningParser(ReasoningParser):
"""The token that ends reasoning content.""" """The token that ends reasoning content."""
raise NotImplementedError raise NotImplementedError
def __init__(self, tokenizer: AnyTokenizer, *args, **kwargs): def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
super().__init__(tokenizer, *args, **kwargs) super().__init__(tokenizer, *args, **kwargs)
if not self.model_tokenizer: if not self.model_tokenizer:

View File

@ -11,7 +11,7 @@ from vllm.entrypoints.openai.protocol import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__) logger = init_logger(__name__)
@ -37,7 +37,7 @@ class MiniMaxM2AppendThinkReasoningParser(ReasoningParser):
Reasoning parser for MiniMax M2 model. Reasoning parser for MiniMax M2 model.
""" """
def __init__(self, tokenizer: AnyTokenizer, *args, **kwargs): def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
super().__init__(tokenizer, *args, **kwargs) super().__init__(tokenizer, *args, **kwargs)
self.end_token_id = self.vocab.get("</think>") self.end_token_id = self.vocab.get("</think>")

View File

@ -6,7 +6,7 @@ from functools import cached_property
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser from vllm.reasoning import ReasoningParser
from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from vllm.tokenizers import MistralTokenizer
logger = init_logger(__name__) logger = init_logger(__name__)

View File

@ -9,7 +9,7 @@ from typing import TYPE_CHECKING
import regex as re import regex as re
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest, ChatCompletionRequest,
@ -220,7 +220,7 @@ class Olmo3ReasoningParser(ReasoningParser):
token is missing from generation. token is missing from generation.
""" """
def __init__(self, tokenizer: "AnyTokenizer", *args, **kwargs): def __init__(self, tokenizer: "TokenizerLike", *args, **kwargs):
super().__init__(tokenizer, *args, **kwargs) super().__init__(tokenizer, *args, **kwargs)
self.think_start = r"<think>" self.think_start = r"<think>"

View File

@ -13,7 +13,7 @@ from pydantic.dataclasses import dataclass
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logits_process import LogitsProcessor from vllm.logits_process import LogitsProcessor
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from vllm.v1.serial_utils import PydanticMsgspecMixin from vllm.v1.serial_utils import PydanticMsgspecMixin
logger = init_logger(__name__) logger = init_logger(__name__)
@ -477,7 +477,7 @@ class SamplingParams(
eos_ids.update(self.stop_token_ids) eos_ids.update(self.stop_token_ids)
self.stop_token_ids = list(eos_ids) self.stop_token_ids = list(eos_ids)
def update_from_tokenizer(self, tokenizer: AnyTokenizer) -> None: def update_from_tokenizer(self, tokenizer: TokenizerLike) -> None:
if not self.bad_words: if not self.bad_words:
return return
self._bad_words_token_ids = [] self._bad_words_token_ids = []

Some files were not shown because too many files have changed in this diff Show More