mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 15:27:13 +08:00
[Misc] Refactor tokenizer interface (#29693)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
f223ed4181
commit
34a984274e
@ -316,7 +316,7 @@ steps:
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/engine
|
||||
- tests/tokenization
|
||||
- tests/tokenizers_
|
||||
- tests/test_sequence
|
||||
- tests/test_config
|
||||
- tests/test_logger
|
||||
@ -324,7 +324,7 @@ steps:
|
||||
commands:
|
||||
- pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py
|
||||
# OOM in the CI unless we run this separately
|
||||
- pytest -v -s tokenization
|
||||
- pytest -v -s tokenizers_
|
||||
|
||||
- label: V1 Test e2e + engine # 30min
|
||||
timeout_in_minutes: 45
|
||||
|
||||
@ -282,7 +282,7 @@ steps:
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/engine
|
||||
- tests/tokenization
|
||||
- tests/tokenizers_
|
||||
- tests/test_sequence
|
||||
- tests/test_config
|
||||
- tests/test_logger
|
||||
@ -290,7 +290,7 @@ steps:
|
||||
commands:
|
||||
- pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py
|
||||
# OOM in the CI unless we run this separately
|
||||
- pytest -v -s tokenization
|
||||
- pytest -v -s tokenizers_
|
||||
|
||||
- label: V1 Test e2e + engine # 30min
|
||||
timeout_in_minutes: 45
|
||||
|
||||
@ -620,7 +620,7 @@ def get_tokenizer(
|
||||
kwargs["use_fast"] = False
|
||||
if tokenizer_mode == "mistral":
|
||||
try:
|
||||
from vllm.transformers_utils.tokenizer import MistralTokenizer
|
||||
from vllm.tokenizers import MistralTokenizer
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"MistralTokenizer requires vllm package.\n"
|
||||
|
||||
@ -216,14 +216,13 @@ You can add a new `ReasoningParser` similar to [vllm/reasoning/deepseek_r1_reaso
|
||||
# import the required packages
|
||||
|
||||
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaMessage)
|
||||
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
|
||||
|
||||
# define a reasoning parser and register it to vllm
|
||||
# the name list in register_module can be used
|
||||
# in --reasoning-parser.
|
||||
class ExampleParser(ReasoningParser):
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
def __init__(self, tokenizer: TokenizerLike):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
def extract_reasoning_streaming(
|
||||
|
||||
@ -422,7 +422,7 @@ Here is a summary of a plugin file:
|
||||
# in --tool-call-parser. you can define as many
|
||||
# tool parsers as you want here.
|
||||
class ExampleToolParser(ToolParser):
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
def __init__(self, tokenizer: TokenizerLike):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
# adjust request. e.g.: set skip special tokens
|
||||
|
||||
@ -10,7 +10,7 @@ import pytest
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.tokenizers import MistralTokenizer
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
|
||||
@ -4,9 +4,9 @@
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def default_tokenizer() -> AnyTokenizer:
|
||||
def default_tokenizer() -> TokenizerLike:
|
||||
return AutoTokenizer.from_pretrained("gpt2")
|
||||
|
||||
@ -7,7 +7,7 @@ import pytest
|
||||
|
||||
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
|
||||
from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
from ....utils import RemoteOpenAIServer
|
||||
|
||||
@ -270,14 +270,14 @@ async def test_streaming_product_tool_call():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def qwen_tokenizer() -> AnyTokenizer:
|
||||
def qwen_tokenizer() -> TokenizerLike:
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
return get_tokenizer("Qwen/Qwen3-32B")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hermes_parser(qwen_tokenizer: AnyTokenizer) -> Hermes2ProToolParser:
|
||||
def hermes_parser(qwen_tokenizer: TokenizerLike) -> Hermes2ProToolParser:
|
||||
return Hermes2ProToolParser(qwen_tokenizer)
|
||||
|
||||
|
||||
@ -291,7 +291,7 @@ def any_chat_request() -> ChatCompletionRequest:
|
||||
|
||||
|
||||
def test_hermes_parser_streaming_just_forward_text(
|
||||
qwen_tokenizer: AnyTokenizer,
|
||||
qwen_tokenizer: TokenizerLike,
|
||||
hermes_parser: Hermes2ProToolParser,
|
||||
any_chat_request: ChatCompletionRequest,
|
||||
) -> None:
|
||||
@ -323,7 +323,7 @@ def test_hermes_parser_streaming_just_forward_text(
|
||||
|
||||
|
||||
def test_hermes_parser_streaming_failure_case_bug_19056(
|
||||
qwen_tokenizer: AnyTokenizer,
|
||||
qwen_tokenizer: TokenizerLike,
|
||||
hermes_parser: Hermes2ProToolParser,
|
||||
any_chat_request: ChatCompletionRequest,
|
||||
) -> None:
|
||||
@ -357,7 +357,7 @@ def test_hermes_parser_streaming_failure_case_bug_19056(
|
||||
|
||||
|
||||
def test_hermes_parser_streaming(
|
||||
qwen_tokenizer: AnyTokenizer,
|
||||
qwen_tokenizer: TokenizerLike,
|
||||
hermes_parser: Hermes2ProToolParser,
|
||||
any_chat_request: ChatCompletionRequest,
|
||||
) -> None:
|
||||
|
||||
@ -7,11 +7,11 @@ import pytest
|
||||
|
||||
from vllm.entrypoints.openai.protocol import ExtractedToolCallInformation
|
||||
from vllm.entrypoints.openai.tool_parsers.llama_tool_parser import Llama3JsonToolParser
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def parser(default_tokenizer: AnyTokenizer):
|
||||
def parser(default_tokenizer: TokenizerLike):
|
||||
return Llama3JsonToolParser(default_tokenizer)
|
||||
|
||||
|
||||
|
||||
@ -11,7 +11,7 @@ from tests.entrypoints.openai.tool_parsers.utils import (
|
||||
)
|
||||
from vllm.entrypoints.openai.protocol import FunctionCall
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
# Test cases similar to pythonic parser but with Llama4 specific format
|
||||
SIMPLE_FUNCTION_OUTPUT = "[get_weather(city='LA', metric='C')]"
|
||||
@ -64,7 +64,7 @@ PYTHON_TAG_FUNCTION_OUTPUT = (
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [True, False])
|
||||
def test_no_tool_call(streaming: bool, default_tokenizer: AnyTokenizer):
|
||||
def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
|
||||
default_tokenizer
|
||||
)
|
||||
@ -208,7 +208,7 @@ def test_tool_call(
|
||||
streaming: bool,
|
||||
model_output: str,
|
||||
expected_tool_calls: list[FunctionCall],
|
||||
default_tokenizer: AnyTokenizer,
|
||||
default_tokenizer: TokenizerLike,
|
||||
):
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
|
||||
default_tokenizer
|
||||
@ -224,7 +224,7 @@ def test_tool_call(
|
||||
assert actual.function == expected
|
||||
|
||||
|
||||
def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
|
||||
def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
|
||||
default_tokenizer
|
||||
)
|
||||
@ -246,7 +246,7 @@ def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [False])
|
||||
def test_regex_timeout_handling(streaming: bool, default_tokenizer: AnyTokenizer):
|
||||
def test_regex_timeout_handling(streaming: bool, default_tokenizer: TokenizerLike):
|
||||
"""test regex timeout is handled gracefully"""
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
|
||||
default_tokenizer
|
||||
|
||||
@ -11,7 +11,7 @@ from tests.entrypoints.openai.tool_parsers.utils import (
|
||||
)
|
||||
from vllm.entrypoints.openai.protocol import FunctionCall
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
|
||||
SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"
|
||||
@ -69,7 +69,7 @@ ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [True, False])
|
||||
def test_no_tool_call(streaming: bool, default_tokenizer: AnyTokenizer):
|
||||
def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
|
||||
default_tokenizer
|
||||
)
|
||||
@ -188,7 +188,7 @@ def test_tool_call(
|
||||
streaming: bool,
|
||||
model_output: str,
|
||||
expected_tool_calls: list[FunctionCall],
|
||||
default_tokenizer: AnyTokenizer,
|
||||
default_tokenizer: TokenizerLike,
|
||||
):
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
|
||||
default_tokenizer
|
||||
@ -205,7 +205,7 @@ def test_tool_call(
|
||||
assert actual.function == expected
|
||||
|
||||
|
||||
def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
|
||||
def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
|
||||
default_tokenizer
|
||||
)
|
||||
@ -228,7 +228,7 @@ def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [False])
|
||||
def test_regex_timeout_handling(streaming: bool, default_tokenizer: AnyTokenizer):
|
||||
def test_regex_timeout_handling(streaming: bool, default_tokenizer: TokenizerLike):
|
||||
"""test regex timeout is handled gracefully"""
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
|
||||
default_tokenizer
|
||||
|
||||
@ -11,7 +11,7 @@ from tests.entrypoints.openai.tool_parsers.utils import (
|
||||
)
|
||||
from vllm.entrypoints.openai.protocol import FunctionCall
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
|
||||
SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"
|
||||
@ -61,7 +61,7 @@ ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [True, False])
|
||||
def test_no_tool_call(streaming: bool, default_tokenizer: AnyTokenizer):
|
||||
def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
|
||||
default_tokenizer
|
||||
)
|
||||
@ -168,7 +168,7 @@ def test_tool_call(
|
||||
streaming: bool,
|
||||
model_output: str,
|
||||
expected_tool_calls: list[FunctionCall],
|
||||
default_tokenizer: AnyTokenizer,
|
||||
default_tokenizer: TokenizerLike,
|
||||
):
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
|
||||
default_tokenizer
|
||||
@ -185,7 +185,7 @@ def test_tool_call(
|
||||
assert actual.function == expected
|
||||
|
||||
|
||||
def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
|
||||
def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
|
||||
default_tokenizer
|
||||
)
|
||||
@ -208,7 +208,7 @@ def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [False])
|
||||
def test_regex_timeout_handling(streaming: bool, default_tokenizer: AnyTokenizer):
|
||||
def test_regex_timeout_handling(streaming: bool, default_tokenizer: TokenizerLike):
|
||||
"""test regex timeout is handled gracefully"""
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
|
||||
default_tokenizer
|
||||
|
||||
@ -11,7 +11,7 @@ from vllm.entrypoints.openai.protocol import (
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
|
||||
class StreamingToolReconstructor:
|
||||
@ -111,7 +111,7 @@ def run_tool_extraction_nonstreaming(
|
||||
return tool_parser.extract_tool_calls(model_output, request)
|
||||
|
||||
|
||||
def split_string_into_token_deltas(tokenizer: AnyTokenizer, text: str) -> list[str]:
|
||||
def split_string_into_token_deltas(tokenizer: TokenizerLike, text: str) -> list[str]:
|
||||
# Split a string into a series of deltas using the provided tokenizer. Each
|
||||
# delta will be the string equivalent of a single token.
|
||||
token_ids = tokenizer.encode(text, add_special_tokens=False)
|
||||
|
||||
@ -28,8 +28,8 @@ from vllm.multimodal.utils import (
|
||||
encode_image_base64,
|
||||
encode_video_base64,
|
||||
)
|
||||
from vllm.tokenizers import MistralTokenizer
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
from ..models.registry import HF_EXAMPLE_MODELS
|
||||
from ..utils import VLLM_PATH
|
||||
|
||||
@ -10,7 +10,7 @@ from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
|
||||
MistralToolParser,
|
||||
)
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.transformers_utils.tokenizer import MistralTokenizer
|
||||
from vllm.tokenizers import MistralTokenizer
|
||||
|
||||
from ...utils import check_logprobs_close
|
||||
|
||||
|
||||
@ -9,7 +9,7 @@ from mistral_common.audio import Audio
|
||||
from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk
|
||||
from mistral_common.protocol.instruct.messages import UserMessage
|
||||
|
||||
from vllm.transformers_utils.tokenizer import MistralTokenizer
|
||||
from vllm.tokenizers import MistralTokenizer
|
||||
|
||||
from ....conftest import AudioTestAssets
|
||||
from ....utils import RemoteOpenAIServer
|
||||
|
||||
@ -9,7 +9,7 @@ import torch
|
||||
from transformers.models.auto.auto_factory import _BaseAutoModelClass
|
||||
|
||||
from vllm.config.model import RunnerOption
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
from .....conftest import HfRunner, VllmRunner
|
||||
from ....registry import HF_EXAMPLE_MODELS
|
||||
@ -33,7 +33,7 @@ def run_test(
|
||||
auto_cls: type[_BaseAutoModelClass],
|
||||
use_tokenizer_eos: bool,
|
||||
comparator: Callable[..., None],
|
||||
get_stop_token_ids: Callable[[AnyTokenizer], list[int]] | None,
|
||||
get_stop_token_ids: Callable[[TokenizerLike], list[int]] | None,
|
||||
stop_str: list[str] | None,
|
||||
limit_mm_per_prompt: dict[str, int],
|
||||
vllm_runner_kwargs: dict[str, Any] | None,
|
||||
|
||||
@ -14,7 +14,7 @@ from transformers.models.auto.auto_factory import _BaseAutoModelClass
|
||||
|
||||
from vllm.config.model import RunnerOption
|
||||
from vllm.logprobs import SampleLogprobs
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
from .....conftest import (
|
||||
AUDIO_ASSETS,
|
||||
@ -126,7 +126,7 @@ class VLMTestInfo(NamedTuple):
|
||||
vllm_runner_kwargs: dict[str, Any] | None = None
|
||||
|
||||
# Optional callable which gets a list of token IDs from the model tokenizer
|
||||
get_stop_token_ids: Callable[[AnyTokenizer], list[int]] | None = None
|
||||
get_stop_token_ids: Callable[[TokenizerLike], list[int]] | None = None
|
||||
# Optional list of strings to stop generation, useful when stop tokens are
|
||||
# not special tokens in the tokenizer
|
||||
stop_str: list[str] | None = None
|
||||
|
||||
@ -22,8 +22,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
|
||||
from vllm.multimodal.cache import MultiModalProcessorOnlyCache
|
||||
from vllm.multimodal.inputs import MultiModalInputs
|
||||
from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext
|
||||
from vllm.tokenizers import MistralTokenizer
|
||||
from vllm.transformers_utils.tokenizer import (
|
||||
MistralTokenizer,
|
||||
cached_tokenizer_from_config,
|
||||
encode_tokens,
|
||||
)
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from typing import cast
|
||||
|
||||
@ -23,7 +24,7 @@ from vllm.multimodal.processing import (
|
||||
replace_token_matches,
|
||||
)
|
||||
from vllm.multimodal.profiling import MultiModalProfiler
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
from .utils import random_image
|
||||
|
||||
@ -238,7 +239,7 @@ def test_find_token_matches(
|
||||
update_type,
|
||||
):
|
||||
# Should not be used since there is nothing to convert to token IDs
|
||||
mock_tokenizer = cast(AnyTokenizer, object())
|
||||
mock_tokenizer = cast(TokenizerLike, object())
|
||||
|
||||
prompt_updates = {
|
||||
key: update_type(key, target, []).resolve(0)
|
||||
@ -385,7 +386,7 @@ def test_find_text_matches(
|
||||
update_type,
|
||||
):
|
||||
# Should not be used since there is nothing to convert to text
|
||||
mock_tokenizer = cast(AnyTokenizer, object())
|
||||
mock_tokenizer = cast(TokenizerLike, object())
|
||||
|
||||
prompt_updates = {
|
||||
key: update_type(key, target, []).resolve(0)
|
||||
@ -545,7 +546,7 @@ def test_find_update_text(
|
||||
expected_by_update_type_mm_count,
|
||||
):
|
||||
# Should not be used since there is nothing to convert to text
|
||||
mock_tokenizer = cast(AnyTokenizer, object())
|
||||
mock_tokenizer = cast(TokenizerLike, object())
|
||||
|
||||
for (
|
||||
update_type,
|
||||
@ -750,7 +751,7 @@ def test_find_update_tokens(
|
||||
expected_by_update_type_mm_count,
|
||||
):
|
||||
# Should not be used since there is nothing to convert to tokens
|
||||
mock_tokenizer = cast(AnyTokenizer, object())
|
||||
mock_tokenizer = cast(TokenizerLike, object())
|
||||
|
||||
for (
|
||||
update_type,
|
||||
@ -900,7 +901,7 @@ def test_find_mm_placeholders(
|
||||
update_type,
|
||||
):
|
||||
# Should not be used since there is nothing to convert to tokens
|
||||
mock_tokenizer = cast(AnyTokenizer, object())
|
||||
mock_tokenizer = cast(TokenizerLike, object())
|
||||
|
||||
mm_prompt_updates = {
|
||||
key: [[update_type(key, [], repl).resolve(i)] for i in range(3)]
|
||||
@ -1029,7 +1030,7 @@ def test_hf_processor_init_kwargs(
|
||||
expected_kwargs,
|
||||
):
|
||||
# Should not be used since there is nothing to convert to tokens
|
||||
mock_tokenizer = cast(AnyTokenizer, object())
|
||||
mock_tokenizer = cast(TokenizerLike, object())
|
||||
|
||||
ctx = InputProcessingContext(
|
||||
model_config=ModelConfig(model_id, mm_processor_kwargs=config_kwargs),
|
||||
@ -1065,7 +1066,7 @@ def test_hf_processor_call_kwargs(
|
||||
expected_kwargs,
|
||||
):
|
||||
# Should not be used since there is nothing to convert to tokens
|
||||
mock_tokenizer = cast(AnyTokenizer, object())
|
||||
mock_tokenizer = cast(TokenizerLike, object())
|
||||
|
||||
ctx = InputProcessingContext(
|
||||
model_config=ModelConfig(model_id, mm_processor_kwargs=config_kwargs),
|
||||
@ -1088,9 +1089,7 @@ def test_apply_matches_no_match_exits_quickly():
|
||||
|
||||
With the fix, it should exit immediately when no match is found.
|
||||
"""
|
||||
import time
|
||||
|
||||
mock_tokenizer = cast(AnyTokenizer, object())
|
||||
mock_tokenizer = cast(TokenizerLike, object())
|
||||
|
||||
# Create a long prompt with no placeholder
|
||||
long_prompt = "x" * 10000
|
||||
|
||||
@ -5,7 +5,7 @@ import pytest
|
||||
|
||||
from tests.reasoning.utils import run_reasoning_extraction_mistral
|
||||
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
||||
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.tokenizers import MistralTokenizer
|
||||
|
||||
parser_name = "mistral"
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
|
||||
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
|
||||
from vllm.reasoning import ReasoningParser
|
||||
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.tokenizers import MistralTokenizer
|
||||
|
||||
|
||||
class StreamingReasoningReconstructor:
|
||||
|
||||
@ -1,18 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
TOKENIZER_NAMES = ["BAAI/bge-base-en"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tokenizer_name", TOKENIZER_NAMES)
|
||||
@pytest.mark.parametrize("n_tokens", [510])
|
||||
def test_special_tokens(tokenizer_name: str, n_tokens: int):
|
||||
tokenizer = get_tokenizer(tokenizer_name, revision="main")
|
||||
|
||||
prompts = "[UNK]" * n_tokens
|
||||
prompt_token_ids = tokenizer.encode(prompts)
|
||||
assert len(prompt_token_ids) == n_tokens + 2
|
||||
@ -1,32 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This test file includes some cases where it is inappropriate to
|
||||
only get the `eos_token_id` from the tokenizer as defined by
|
||||
{meth}`vllm.LLMEngine._get_eos_token_id`.
|
||||
"""
|
||||
|
||||
from vllm.transformers_utils.config import try_get_generation_config
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
|
||||
def test_get_llama3_eos_token():
|
||||
model_name = "meta-llama/Llama-3.2-1B-Instruct"
|
||||
|
||||
tokenizer = get_tokenizer(model_name)
|
||||
assert tokenizer.eos_token_id == 128009
|
||||
|
||||
generation_config = try_get_generation_config(model_name, trust_remote_code=False)
|
||||
assert generation_config is not None
|
||||
assert generation_config.eos_token_id == [128001, 128008, 128009]
|
||||
|
||||
|
||||
def test_get_blip2_eos_token():
|
||||
model_name = "Salesforce/blip2-opt-2.7b"
|
||||
|
||||
tokenizer = get_tokenizer(model_name)
|
||||
assert tokenizer.eos_token_id == 2
|
||||
|
||||
generation_config = try_get_generation_config(model_name, trust_remote_code=False)
|
||||
assert generation_config is not None
|
||||
assert generation_config.eos_token_id == 50118
|
||||
@ -1,23 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
TOKENIZER_NAMES = [
|
||||
"facebook/opt-125m",
|
||||
"gpt2",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tokenizer_name", TOKENIZER_NAMES)
|
||||
def test_tokenizer_revision(tokenizer_name: str):
|
||||
# Assume that "main" branch always exists
|
||||
tokenizer = get_tokenizer(tokenizer_name, revision="main")
|
||||
assert isinstance(tokenizer, PreTrainedTokenizerBase)
|
||||
|
||||
# Assume that "never" branch always does not exist
|
||||
with pytest.raises(OSError, match="not a valid git identifier"):
|
||||
get_tokenizer(tokenizer_name, revision="never")
|
||||
@ -1,120 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
from vllm.transformers_utils.tokenizer_base import TokenizerBase, TokenizerRegistry
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||
|
||||
|
||||
class TestTokenizer(TokenizerBase):
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs) -> "TestTokenizer":
|
||||
return TestTokenizer()
|
||||
|
||||
@property
|
||||
def all_special_tokens(self) -> list[str]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def all_special_ids(self) -> list[int]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def bos_token_id(self) -> int:
|
||||
return 0
|
||||
|
||||
@property
|
||||
def eos_token_id(self) -> int:
|
||||
return 1
|
||||
|
||||
@property
|
||||
def sep_token(self) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def pad_token(self) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def is_fast(self) -> bool:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def max_token_id(self) -> int:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def truncation_side(self) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: str | list[str] | list[int],
|
||||
text_pair: str | None = None,
|
||||
add_special_tokens: bool = False,
|
||||
truncation: bool = False,
|
||||
max_length: int | None = None,
|
||||
):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_vocab(self) -> dict[str, int]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_added_vocab(self) -> dict[str, int]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def encode_one(
|
||||
self,
|
||||
text: str,
|
||||
truncation: bool = False,
|
||||
max_length: int | None = None,
|
||||
) -> list[int]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def encode(self, text: str, add_special_tokens: bool | None = None) -> list[int]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def apply_chat_template(
|
||||
self,
|
||||
messages: list["ChatCompletionMessageParam"],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
**kwargs,
|
||||
) -> list[int]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def convert_tokens_to_string(self, tokens: list[str]) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
def decode(self, ids: list[int] | int, skip_special_tokens: bool = True) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
def convert_ids_to_tokens(
|
||||
self,
|
||||
ids: list[int],
|
||||
skip_special_tokens: bool = True,
|
||||
) -> list[str]:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def test_customized_tokenizer():
|
||||
TokenizerRegistry.register(
|
||||
"test_tokenizer", "tests.tokenization.test_tokenizer_registry", "TestTokenizer"
|
||||
)
|
||||
|
||||
tokenizer = TokenizerRegistry.get_tokenizer("test_tokenizer")
|
||||
assert isinstance(tokenizer, TestTokenizer)
|
||||
assert tokenizer.bos_token_id == 0
|
||||
assert tokenizer.eos_token_id == 1
|
||||
|
||||
tokenizer = get_tokenizer("test_tokenizer", tokenizer_mode="custom")
|
||||
assert isinstance(tokenizer, TestTokenizer)
|
||||
assert tokenizer.bos_token_id == 0
|
||||
assert tokenizer.eos_token_id == 1
|
||||
4
tests/tokenizers_/__init__.py
Normal file
4
tests/tokenizers_/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# NOTE: Since CI runs the tests from the `tests` directory, it is necessary to rename
|
||||
# this module to avoid conflicting with HF's `tokenizers` package
|
||||
59
tests/tokenizers_/test_basic.py
Normal file
59
tests/tokenizers_/test_basic.py
Normal file
@ -0,0 +1,59 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import _get_protocol_attrs # type: ignore
|
||||
|
||||
import pytest
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
|
||||
def _get_missing_attrs(obj: object, target: type):
|
||||
return [k for k in _get_protocol_attrs(target) if not hasattr(obj, k)]
|
||||
|
||||
|
||||
def test_tokenizer_like_protocol():
|
||||
assert not (
|
||||
missing_attrs := _get_missing_attrs(
|
||||
get_tokenizer("gpt2", use_fast=False),
|
||||
TokenizerLike,
|
||||
)
|
||||
), f"Missing attrs: {missing_attrs}"
|
||||
|
||||
assert not (
|
||||
missing_attrs := _get_missing_attrs(
|
||||
get_tokenizer("gpt2", use_fast=True),
|
||||
TokenizerLike,
|
||||
)
|
||||
), f"Missing attrs: {missing_attrs}"
|
||||
|
||||
assert not (
|
||||
missing_attrs := _get_missing_attrs(
|
||||
get_tokenizer(
|
||||
"mistralai/Mistral-7B-Instruct-v0.3", tokenizer_mode="mistral"
|
||||
),
|
||||
TokenizerLike,
|
||||
)
|
||||
), f"Missing attrs: {missing_attrs}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tokenizer_name", ["facebook/opt-125m", "gpt2"])
|
||||
def test_tokenizer_revision(tokenizer_name: str):
|
||||
# Assume that "main" branch always exists
|
||||
tokenizer = get_tokenizer(tokenizer_name, revision="main")
|
||||
assert isinstance(tokenizer, PreTrainedTokenizerBase)
|
||||
|
||||
# Assume that "never" branch always does not exist
|
||||
with pytest.raises(OSError, match="not a valid git identifier"):
|
||||
get_tokenizer(tokenizer_name, revision="never")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tokenizer_name", ["BAAI/bge-base-en"])
|
||||
@pytest.mark.parametrize("n_tokens", [510])
|
||||
def test_special_tokens(tokenizer_name: str, n_tokens: int):
|
||||
tokenizer = get_tokenizer(tokenizer_name, revision="main")
|
||||
|
||||
prompts = "[UNK]" * n_tokens
|
||||
prompt_token_ids = tokenizer.encode(prompts)
|
||||
assert len(prompt_token_ids) == n_tokens + 2
|
||||
@ -6,7 +6,8 @@ from copy import deepcopy
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_cached_tokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.transformers_utils.tokenizer import get_cached_tokenizer
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_id", ["gpt2", "zai-org/chatglm3-6b"])
|
||||
@ -25,7 +26,7 @@ def test_cached_tokenizer(model_id: str):
|
||||
_check_consistency(unpickled_tokenizer, reference_tokenizer)
|
||||
|
||||
|
||||
def _check_consistency(target: AnyTokenizer, expected: AnyTokenizer):
|
||||
def _check_consistency(target: TokenizerLike, expected: TokenizerLike):
|
||||
assert isinstance(target, type(expected))
|
||||
|
||||
# Cached attributes
|
||||
@ -8,7 +8,7 @@ import pytest
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.tokenizers import MistralTokenizer
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.detokenizer import (
|
||||
FastIncrementalDetokenizer,
|
||||
@ -7,7 +7,7 @@ import pytest
|
||||
from mistral_common.exceptions import InvalidMessageStructureException
|
||||
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
|
||||
|
||||
from vllm.transformers_utils.tokenizers.mistral import (
|
||||
from vllm.tokenizers.mistral import (
|
||||
MistralTokenizer,
|
||||
_prepare_apply_chat_template_tools_and_messages,
|
||||
)
|
||||
@ -308,25 +308,6 @@ class TestMistralTokenizer:
|
||||
def test_get_added_vocab(self, mistral_tokenizer: MistralTokenizer):
|
||||
assert mistral_tokenizer.get_added_vocab() == {}
|
||||
|
||||
def test_encode_one(self, mistral_tokenizer: MistralTokenizer):
|
||||
token_ids = (
|
||||
[22177, 4304, 2662] if mistral_tokenizer.is_tekken else [23325, 2294, 1686]
|
||||
)
|
||||
|
||||
assert mistral_tokenizer.encode_one("Hello world !") == token_ids
|
||||
assert mistral_tokenizer.encode_one("Hello world !", max_length=1) == token_ids
|
||||
assert (
|
||||
mistral_tokenizer.encode_one("Hello world !", truncation=True, max_length=1)
|
||||
== token_ids[:-2]
|
||||
)
|
||||
assert (
|
||||
mistral_tokenizer.encode_one(
|
||||
"Hello world !", truncation=False, max_length=1
|
||||
)
|
||||
== token_ids
|
||||
)
|
||||
assert mistral_tokenizer.encode_one("") == []
|
||||
|
||||
def test_encode(self, mistral_tokenizer: MistralTokenizer):
|
||||
token_ids = (
|
||||
[1, 22177, 4304, 2662]
|
||||
36
tests/tokenizers_/test_registry.py
Normal file
36
tests/tokenizers_/test_registry.py
Normal file
@ -0,0 +1,36 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from vllm.tokenizers import TokenizerLike, TokenizerRegistry
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
|
||||
class TestTokenizer(TokenizerLike):
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs) -> "TestTokenizer":
|
||||
return TestTokenizer() # type: ignore
|
||||
|
||||
@property
|
||||
def bos_token_id(self) -> int:
|
||||
return 0
|
||||
|
||||
@property
|
||||
def eos_token_id(self) -> int:
|
||||
return 1
|
||||
|
||||
|
||||
def test_customized_tokenizer():
|
||||
TokenizerRegistry.register(
|
||||
"test_tokenizer",
|
||||
__name__,
|
||||
TestTokenizer.__name__,
|
||||
)
|
||||
|
||||
tokenizer = TokenizerRegistry.get_tokenizer("test_tokenizer")
|
||||
assert isinstance(tokenizer, TestTokenizer)
|
||||
assert tokenizer.bos_token_id == 0
|
||||
assert tokenizer.eos_token_id == 1
|
||||
|
||||
tokenizer = get_tokenizer("test_tokenizer", tokenizer_mode="custom")
|
||||
assert isinstance(tokenizer, TestTokenizer)
|
||||
assert tokenizer.bos_token_id == 0
|
||||
assert tokenizer.eos_token_id == 1
|
||||
@ -14,8 +14,9 @@ from vllm.entrypoints.openai.protocol import (
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.entrypoints.openai.tool_parsers.ernie45_tool_parser import Ernie45ToolParser
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
# Use a common model that is likely to be available
|
||||
MODEL = "baidu/ERNIE-4.5-21B-A3B-Thinking"
|
||||
@ -173,7 +174,7 @@ def test_extract_tool_calls(
|
||||
|
||||
def stream_delta_message_generator(
|
||||
ernie45_tool_parser: Ernie45ToolParser,
|
||||
ernie45_tokenizer: AnyTokenizer,
|
||||
ernie45_tokenizer: TokenizerLike,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest | None = None,
|
||||
) -> Generator[DeltaMessage, None, None]:
|
||||
|
||||
@ -10,8 +10,9 @@ from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.openai.protocol import DeltaMessage, FunctionCall, ToolCall
|
||||
from vllm.entrypoints.openai.tool_parsers.jamba_tool_parser import JambaToolParser
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
@ -44,7 +45,9 @@ def assert_tool_calls(
|
||||
|
||||
|
||||
def stream_delta_message_generator(
|
||||
jamba_tool_parser: JambaToolParser, jamba_tokenizer: AnyTokenizer, model_output: str
|
||||
jamba_tool_parser: JambaToolParser,
|
||||
jamba_tokenizer: TokenizerLike,
|
||||
model_output: str,
|
||||
) -> Generator[DeltaMessage, None, None]:
|
||||
all_token_ids = jamba_tokenizer.encode(model_output, add_special_tokens=False)
|
||||
|
||||
|
||||
@ -17,8 +17,9 @@ from vllm.entrypoints.openai.tool_parsers.qwen3coder_tool_parser import (
|
||||
Qwen3CoderToolParser,
|
||||
)
|
||||
from vllm.entrypoints.openai.tool_parsers.qwen3xml_tool_parser import Qwen3XMLToolParser
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
@ -104,7 +105,7 @@ def assert_tool_calls(
|
||||
|
||||
def stream_delta_message_generator(
|
||||
qwen3_tool_parser,
|
||||
qwen3_tokenizer: AnyTokenizer,
|
||||
qwen3_tokenizer: TokenizerLike,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest | None = None,
|
||||
) -> Generator[DeltaMessage, None, None]:
|
||||
|
||||
@ -15,8 +15,9 @@ from vllm.entrypoints.openai.protocol import (
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.entrypoints.openai.tool_parsers.seed_oss_tool_parser import SeedOssToolParser
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
@ -256,7 +257,7 @@ def test_streaming_tool_calls_no_tools(seed_oss_tool_parser):
|
||||
|
||||
def stream_delta_message_generator(
|
||||
seed_oss_tool_parser: SeedOssToolParser,
|
||||
seed_oss_tokenizer: AnyTokenizer,
|
||||
seed_oss_tokenizer: TokenizerLike,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest | None = None,
|
||||
) -> Generator[DeltaMessage, None, None]:
|
||||
|
||||
@ -13,8 +13,9 @@ from vllm.entrypoints.openai.protocol import (
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.entrypoints.openai.tool_parsers.xlam_tool_parser import xLAMToolParser
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
@ -49,7 +50,7 @@ def assert_tool_calls(
|
||||
|
||||
def stream_delta_message_generator(
|
||||
xlam_tool_parser: xLAMToolParser,
|
||||
xlam_tokenizer: AnyTokenizer,
|
||||
xlam_tokenizer: TokenizerLike,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest | None = None,
|
||||
) -> Generator[DeltaMessage, None, None]:
|
||||
|
||||
@ -1,62 +1,32 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This test file includes some cases where it is inappropriate to
|
||||
only get the `eos_token_id` from the tokenizer as defined by
|
||||
`vllm.LLMEngine._get_eos_token_id`.
|
||||
"""
|
||||
|
||||
from vllm.transformers_utils.config import try_get_generation_config
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
def test_get_llama3_eos_token():
|
||||
model_name = "meta-llama/Llama-3.2-1B-Instruct"
|
||||
|
||||
import pytest
|
||||
tokenizer = get_tokenizer(model_name)
|
||||
assert tokenizer.eos_token_id == 128009
|
||||
|
||||
from vllm.transformers_utils.repo_utils import list_filtered_repo_files
|
||||
generation_config = try_get_generation_config(model_name, trust_remote_code=False)
|
||||
assert generation_config is not None
|
||||
assert generation_config.eos_token_id == [128001, 128008, 128009]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"allow_patterns,expected_relative_files",
|
||||
[
|
||||
(
|
||||
["*.json", "correct*.txt"],
|
||||
["json_file.json", "subfolder/correct.txt", "correct_2.txt"],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_list_filtered_repo_files(
|
||||
allow_patterns: list[str], expected_relative_files: list[str]
|
||||
):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Prep folder and files
|
||||
path_tmp_dir = Path(tmp_dir)
|
||||
subfolder = path_tmp_dir / "subfolder"
|
||||
subfolder.mkdir()
|
||||
(path_tmp_dir / "json_file.json").touch()
|
||||
(path_tmp_dir / "correct_2.txt").touch()
|
||||
(path_tmp_dir / "uncorrect.txt").touch()
|
||||
(path_tmp_dir / "uncorrect.jpeg").touch()
|
||||
(subfolder / "correct.txt").touch()
|
||||
(subfolder / "uncorrect_sub.txt").touch()
|
||||
def test_get_blip2_eos_token():
|
||||
model_name = "Salesforce/blip2-opt-2.7b"
|
||||
|
||||
def _glob_path() -> list[str]:
|
||||
return [
|
||||
str(file.relative_to(path_tmp_dir))
|
||||
for file in path_tmp_dir.glob("**/*")
|
||||
if file.is_file()
|
||||
]
|
||||
tokenizer = get_tokenizer(model_name)
|
||||
assert tokenizer.eos_token_id == 2
|
||||
|
||||
# Patch list_repo_files called by fn
|
||||
with patch(
|
||||
"vllm.transformers_utils.repo_utils.list_repo_files",
|
||||
MagicMock(return_value=_glob_path()),
|
||||
) as mock_list_repo_files:
|
||||
out_files = sorted(
|
||||
list_filtered_repo_files(
|
||||
tmp_dir, allow_patterns, "revision", "model", "token"
|
||||
)
|
||||
)
|
||||
assert out_files == sorted(expected_relative_files)
|
||||
assert mock_list_repo_files.call_count == 1
|
||||
assert mock_list_repo_files.call_args_list[0] == call(
|
||||
repo_id=tmp_dir,
|
||||
revision="revision",
|
||||
repo_type="model",
|
||||
token="token",
|
||||
)
|
||||
generation_config = try_get_generation_config(model_name, trust_remote_code=False)
|
||||
assert generation_config is not None
|
||||
assert generation_config.eos_token_id == 50118
|
||||
|
||||
62
tests/transformers_utils/test_repo_utils.py
Normal file
62
tests/transformers_utils/test_repo_utils.py
Normal file
@ -0,0 +1,62 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.transformers_utils.repo_utils import list_filtered_repo_files
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"allow_patterns,expected_relative_files",
|
||||
[
|
||||
(
|
||||
["*.json", "correct*.txt"],
|
||||
["json_file.json", "subfolder/correct.txt", "correct_2.txt"],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_list_filtered_repo_files(
|
||||
allow_patterns: list[str], expected_relative_files: list[str]
|
||||
):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Prep folder and files
|
||||
path_tmp_dir = Path(tmp_dir)
|
||||
subfolder = path_tmp_dir / "subfolder"
|
||||
subfolder.mkdir()
|
||||
(path_tmp_dir / "json_file.json").touch()
|
||||
(path_tmp_dir / "correct_2.txt").touch()
|
||||
(path_tmp_dir / "uncorrect.txt").touch()
|
||||
(path_tmp_dir / "uncorrect.jpeg").touch()
|
||||
(subfolder / "correct.txt").touch()
|
||||
(subfolder / "uncorrect_sub.txt").touch()
|
||||
|
||||
def _glob_path() -> list[str]:
|
||||
return [
|
||||
str(file.relative_to(path_tmp_dir))
|
||||
for file in path_tmp_dir.glob("**/*")
|
||||
if file.is_file()
|
||||
]
|
||||
|
||||
# Patch list_repo_files called by fn
|
||||
with patch(
|
||||
"vllm.transformers_utils.repo_utils.list_repo_files",
|
||||
MagicMock(return_value=_glob_path()),
|
||||
) as mock_list_repo_files:
|
||||
out_files = sorted(
|
||||
list_filtered_repo_files(
|
||||
tmp_dir, allow_patterns, "revision", "model", "token"
|
||||
)
|
||||
)
|
||||
assert out_files == sorted(expected_relative_files)
|
||||
assert mock_list_repo_files.call_count == 1
|
||||
assert mock_list_repo_files.call_args_list[0] == call(
|
||||
repo_id=tmp_dir,
|
||||
revision="revision",
|
||||
repo_type="model",
|
||||
token="token",
|
||||
)
|
||||
@ -18,7 +18,7 @@ from vllm.logprobs import PromptLogprobs, SampleLogprobs
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.v1.engine import (
|
||||
EngineCoreEvent,
|
||||
EngineCoreEventType,
|
||||
@ -31,7 +31,7 @@ from vllm.v1.metrics.stats import IterationStats, SchedulerStats
|
||||
|
||||
|
||||
def _ref_convert_id_to_token(
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
token_id: int,
|
||||
) -> str:
|
||||
"""Reference impl of logprobs detokenization.
|
||||
|
||||
@ -27,8 +27,8 @@ ALLOWED_FILES = {
|
||||
"vllm/distributed/device_communicators/shm_broadcast.py",
|
||||
"vllm/distributed/device_communicators/shm_object_storage.py",
|
||||
"vllm/utils/hashing.py",
|
||||
"tests/tokenizers_/test_cached_tokenizer.py",
|
||||
"tests/utils_/test_hashing.py",
|
||||
"tests/tokenization/test_cached_tokenizer.py",
|
||||
"benchmarks/kernels/graph_machete_bench.py",
|
||||
"benchmarks/kernels/benchmark_lora.py",
|
||||
"benchmarks/kernels/benchmark_machete.py",
|
||||
|
||||
@ -35,6 +35,7 @@ FILES = [
|
||||
"vllm/multimodal",
|
||||
"vllm/platforms",
|
||||
"vllm/plugins",
|
||||
"vllm/tokenizers",
|
||||
"vllm/transformers_utils",
|
||||
"vllm/triton_utils",
|
||||
"vllm/usage",
|
||||
|
||||
@ -39,7 +39,7 @@ from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.utils import get_adapter_absolute_path
|
||||
from vllm.multimodal import MultiModalDataDict
|
||||
from vllm.multimodal.image import convert_image_mode
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
try:
|
||||
@ -293,7 +293,7 @@ def lora_path_on_disk(lora_path: str) -> str:
|
||||
|
||||
|
||||
# Global cache for LoRA tokenizers.
|
||||
lora_tokenizer_cache: dict[int, AnyTokenizer] = {}
|
||||
lora_tokenizer_cache: dict[int, TokenizerLike] = {}
|
||||
|
||||
|
||||
def process_image(image: Any) -> Mapping[str, Any]:
|
||||
|
||||
@ -13,7 +13,7 @@ from vllm.plugins.io_processors import IOProcessor
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.tasks import SupportedTask
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.input_processor import InputProcessor
|
||||
|
||||
@ -85,7 +85,7 @@ class EngineClient(ABC):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_tokenizer(self) -> AnyTokenizer:
|
||||
async def get_tokenizer(self) -> TokenizerLike:
|
||||
"""Get the tokenizer"""
|
||||
...
|
||||
|
||||
|
||||
@ -49,9 +49,9 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.models import SupportsMultiModal
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict
|
||||
from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector
|
||||
from vllm.tokenizers import MistralTokenizer, TokenizerLike
|
||||
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
|
||||
from vllm.transformers_utils.processor import cached_get_processor
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils.func_utils import supports_kw
|
||||
|
||||
@ -536,7 +536,7 @@ def resolve_hf_chat_template(
|
||||
def _resolve_chat_template_content_format(
|
||||
chat_template: str | None,
|
||||
tools: list[dict[str, Any]] | None,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
*,
|
||||
model_config: ModelConfig,
|
||||
) -> _ChatTemplateContentFormat:
|
||||
@ -593,7 +593,7 @@ def resolve_chat_template_content_format(
|
||||
chat_template: str | None,
|
||||
tools: list[dict[str, Any]] | None,
|
||||
given_format: ChatTemplateContentFormatOption,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
*,
|
||||
model_config: ModelConfig,
|
||||
) -> _ChatTemplateContentFormat:
|
||||
@ -627,7 +627,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
maximum per prompt.
|
||||
"""
|
||||
|
||||
def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):
|
||||
def __init__(self, model_config: ModelConfig, tokenizer: TokenizerLike):
|
||||
super().__init__()
|
||||
|
||||
self._model_config = model_config
|
||||
@ -1592,7 +1592,7 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None:
|
||||
def parse_chat_messages(
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
model_config: ModelConfig,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
content_format: _ChatTemplateContentFormat,
|
||||
) -> tuple[
|
||||
list[ConversationMessage],
|
||||
@ -1624,7 +1624,7 @@ def parse_chat_messages(
|
||||
def parse_chat_messages_futures(
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
model_config: ModelConfig,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
content_format: _ChatTemplateContentFormat,
|
||||
) -> tuple[
|
||||
list[ConversationMessage],
|
||||
|
||||
@ -71,11 +71,8 @@ from vllm.platforms import current_platform
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.transformers_utils.tokenizer import (
|
||||
AnyTokenizer,
|
||||
MistralTokenizer,
|
||||
get_cached_tokenizer,
|
||||
)
|
||||
from vllm.tokenizers import MistralTokenizer, TokenizerLike
|
||||
from vllm.transformers_utils.tokenizer import get_cached_tokenizer
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils.collection_utils import as_iter, is_list_of
|
||||
from vllm.utils.counter import Counter
|
||||
@ -350,11 +347,11 @@ class LLM:
|
||||
self.input_processor = self.llm_engine.input_processor
|
||||
self.io_processor = self.llm_engine.io_processor
|
||||
|
||||
def get_tokenizer(self) -> AnyTokenizer:
|
||||
def get_tokenizer(self) -> TokenizerLike:
|
||||
return self.llm_engine.get_tokenizer()
|
||||
|
||||
@deprecated("`set_tokenizer` is deprecated and will be removed in v0.13.")
|
||||
def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
|
||||
def set_tokenizer(self, tokenizer: TokenizerLike) -> None:
|
||||
# While CachedTokenizer is dynamic, have no choice but
|
||||
# compare class name. Misjudgment will arise from
|
||||
# user-defined tokenizer started with 'Cached'
|
||||
@ -1244,7 +1241,7 @@ class LLM:
|
||||
|
||||
def _embedding_score(
|
||||
self,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
text_1: list[str | TextPrompt | TokensPrompt],
|
||||
text_2: list[str | TextPrompt | TokensPrompt],
|
||||
truncate_prompt_tokens: int | None = None,
|
||||
@ -1276,7 +1273,7 @@ class LLM:
|
||||
|
||||
def _cross_encoding_score(
|
||||
self,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
data_1: list[str] | list[ScoreContentPartParam],
|
||||
data_2: list[str] | list[ScoreContentPartParam],
|
||||
truncate_prompt_tokens: int | None = None,
|
||||
|
||||
@ -62,8 +62,9 @@ from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.transformers_utils.tokenizers import (
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tokenizers.mistral import (
|
||||
MistralTokenizer,
|
||||
maybe_serialize_tool_calls,
|
||||
truncate_tool_call_ids,
|
||||
validate_request_params,
|
||||
@ -530,7 +531,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
request_id: str,
|
||||
model_name: str,
|
||||
conversation: list[ConversationMessage],
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
created_time = int(time.time())
|
||||
@ -1296,7 +1297,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
request_id: str,
|
||||
model_name: str,
|
||||
conversation: list[ConversationMessage],
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
) -> ErrorResponse | ChatCompletionResponse:
|
||||
created_time = int(time.time())
|
||||
@ -1624,7 +1625,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
self,
|
||||
logprobs: dict[int, Logprob],
|
||||
top_logprobs: int | None,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
should_return_as_token_id: bool,
|
||||
) -> list[ChatCompletionLogProb]:
|
||||
return [
|
||||
@ -1648,7 +1649,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
self,
|
||||
token_ids: GenericSequence[int],
|
||||
top_logprobs: GenericSequence[dict[int, Logprob] | None],
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
num_output_top_logprobs: int | None = None,
|
||||
return_as_token_id: bool | None = None,
|
||||
) -> ChatCompletionLogProbs:
|
||||
|
||||
@ -221,7 +221,7 @@ class ServingClassification(ClassificationMixin):
|
||||
|
||||
def _create_pooling_params(
|
||||
self,
|
||||
ctx: ClassificationServeContext,
|
||||
ctx: ServeContext[ClassificationRequest],
|
||||
) -> PoolingParams | ErrorResponse:
|
||||
pooling_params = super()._create_pooling_params(ctx)
|
||||
if isinstance(pooling_params, ErrorResponse):
|
||||
|
||||
@ -33,7 +33,7 @@ from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils.async_utils import merge_async_iterators
|
||||
from vllm.utils.collection_utils import as_list
|
||||
from vllm.v1.sample.logits_processor import validate_logits_processors_parameters
|
||||
@ -326,7 +326,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
created_time: int,
|
||||
model_name: str,
|
||||
num_prompts: int,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike | None,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
num_choices = 1 if request.n is None else request.n
|
||||
@ -511,7 +511,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
request_id: str,
|
||||
created_time: int,
|
||||
model_name: str,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike | None,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
) -> CompletionResponse:
|
||||
choices: list[CompletionResponseChoice] = []
|
||||
@ -622,7 +622,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
token_ids: GenericSequence[int],
|
||||
top_logprobs: GenericSequence[dict[int, Logprob] | None],
|
||||
num_output_top_logprobs: int,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike | None,
|
||||
initial_text_offset: int = 0,
|
||||
return_as_token_id: bool | None = None,
|
||||
) -> CompletionLogProbs:
|
||||
@ -642,9 +642,15 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
for i, token_id in enumerate(token_ids):
|
||||
step_top_logprobs = top_logprobs[i]
|
||||
if step_top_logprobs is None:
|
||||
token = tokenizer.decode(token_id)
|
||||
if should_return_as_token_id:
|
||||
token = f"token_id:{token_id}"
|
||||
else:
|
||||
if tokenizer is None:
|
||||
raise ValueError(
|
||||
"Unable to get tokenizer because `skip_tokenizer_init=True`"
|
||||
)
|
||||
|
||||
token = tokenizer.decode(token_id)
|
||||
|
||||
out_tokens.append(token)
|
||||
out_token_logprobs.append(None)
|
||||
|
||||
@ -7,13 +7,14 @@ import time
|
||||
import traceback
|
||||
from collections.abc import AsyncGenerator, Callable, Iterable, Mapping, Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass, field
|
||||
from http import HTTPStatus
|
||||
from typing import Any, ClassVar, Generic, TypeAlias, TypeVar
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from fastapi import Request
|
||||
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
|
||||
from pydantic import ConfigDict, TypeAdapter
|
||||
from starlette.datastructures import Headers
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
@ -96,12 +97,12 @@ from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.tokenizers import MistralTokenizer, TokenizerLike
|
||||
from vllm.tracing import (
|
||||
contains_trace_headers,
|
||||
extract_trace_headers,
|
||||
log_tracing_disabled_warning,
|
||||
)
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils.async_utils import (
|
||||
AsyncMicrobatchTokenizer,
|
||||
@ -184,19 +185,19 @@ def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
|
||||
RequestT = TypeVar("RequestT", bound=AnyRequest)
|
||||
|
||||
|
||||
class RequestProcessingMixin(BaseModel):
|
||||
@dataclass(kw_only=True)
|
||||
class RequestProcessingMixin:
|
||||
"""
|
||||
Mixin for request processing,
|
||||
handling prompt preparation and engine input.
|
||||
"""
|
||||
|
||||
request_prompts: Sequence[RequestPrompt] | None = []
|
||||
engine_prompts: list[EngineTokensPrompt] | None = []
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
request_prompts: Sequence[RequestPrompt] | None = field(default_factory=list)
|
||||
engine_prompts: list[EngineTokensPrompt] | None = field(default_factory=list)
|
||||
|
||||
|
||||
class ResponseGenerationMixin(BaseModel):
|
||||
@dataclass(kw_only=True)
|
||||
class ResponseGenerationMixin:
|
||||
"""
|
||||
Mixin for response generation,
|
||||
managing result generators and final batch results.
|
||||
@ -205,54 +206,38 @@ class ResponseGenerationMixin(BaseModel):
|
||||
result_generator: (
|
||||
AsyncGenerator[tuple[int, RequestOutput | PoolingRequestOutput], None] | None
|
||||
) = None
|
||||
final_res_batch: list[RequestOutput | PoolingRequestOutput] = Field(
|
||||
final_res_batch: list[RequestOutput | PoolingRequestOutput] = field(
|
||||
default_factory=list
|
||||
)
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class ServeContext(
|
||||
RequestProcessingMixin,
|
||||
ResponseGenerationMixin,
|
||||
BaseModel,
|
||||
Generic[RequestT],
|
||||
):
|
||||
@dataclass(kw_only=True)
|
||||
class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, Generic[RequestT]):
|
||||
# Shared across all requests
|
||||
request: RequestT
|
||||
raw_request: Request | None = None
|
||||
model_name: str
|
||||
request_id: str
|
||||
created_time: int = Field(default_factory=lambda: int(time.time()))
|
||||
created_time: int = field(default_factory=lambda: int(time.time()))
|
||||
lora_request: LoRARequest | None = None
|
||||
|
||||
# Shared across most requests
|
||||
tokenizer: AnyTokenizer | None = None
|
||||
|
||||
# `protected_namespaces` resolves Pydantic v2's warning
|
||||
# on conflict with protected namespace "model_"
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
tokenizer: TokenizerLike | None = None
|
||||
|
||||
|
||||
ClassificationServeContext = ServeContext[ClassificationRequest]
|
||||
@dataclass(kw_only=True)
|
||||
class ClassificationServeContext(ServeContext[ClassificationRequest]):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class EmbeddingServeContext(ServeContext[EmbeddingRequest]):
|
||||
chat_template: str | None = None
|
||||
chat_template_content_format: ChatTemplateContentFormatOption
|
||||
|
||||
|
||||
# Used to resolve the Pydantic error related to
|
||||
# forward reference of MultiModalDataDict in TokensPrompt
|
||||
RequestProcessingMixin.model_rebuild()
|
||||
ServeContext.model_rebuild()
|
||||
ClassificationServeContext.model_rebuild()
|
||||
EmbeddingServeContext.model_rebuild()
|
||||
|
||||
|
||||
class OpenAIServing:
|
||||
request_id_prefix: ClassVar[str] = """
|
||||
A short string prepended to every request’s ID (e.g. "embd", "classify")
|
||||
@ -281,7 +266,7 @@ class OpenAIServing:
|
||||
apply_mistral_chat_template, executor=self._tokenizer_executor
|
||||
)
|
||||
|
||||
self._async_tokenizer_pool: dict[AnyTokenizer, AsyncMicrobatchTokenizer] = {}
|
||||
self._async_tokenizer_pool: dict[TokenizerLike, AsyncMicrobatchTokenizer] = {}
|
||||
self.log_error_stack = log_error_stack
|
||||
|
||||
self.input_processor = self.models.input_processor
|
||||
@ -291,7 +276,7 @@ class OpenAIServing:
|
||||
|
||||
def _get_tool_parser(
|
||||
self, tool_parser_name: str | None = None, enable_auto_tools: bool = False
|
||||
) -> Callable[[AnyTokenizer], ToolParser] | None:
|
||||
) -> Callable[[TokenizerLike], ToolParser] | None:
|
||||
"""Get the tool parser based on the name."""
|
||||
parser = None
|
||||
if not enable_auto_tools or tool_parser_name is None:
|
||||
@ -317,7 +302,7 @@ class OpenAIServing:
|
||||
def _get_reasoning_parser(
|
||||
self,
|
||||
reasoning_parser_name: str,
|
||||
) -> Callable[[AnyTokenizer], ReasoningParser] | None:
|
||||
) -> Callable[[TokenizerLike], ReasoningParser] | None:
|
||||
"""Get the reasoning parser based on the name."""
|
||||
parser = None
|
||||
if not reasoning_parser_name:
|
||||
@ -547,7 +532,7 @@ class OpenAIServing:
|
||||
prompt_logprobs=None,
|
||||
)
|
||||
|
||||
def _get_renderer(self, tokenizer: AnyTokenizer | None) -> BaseRenderer:
|
||||
def _get_renderer(self, tokenizer: TokenizerLike | None) -> BaseRenderer:
|
||||
"""
|
||||
Get a Renderer instance with the provided tokenizer.
|
||||
Uses shared async tokenizer pool for efficiency.
|
||||
@ -877,7 +862,7 @@ class OpenAIServing:
|
||||
self,
|
||||
request: AnyRequest,
|
||||
prompt: str,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
add_special_tokens: bool,
|
||||
) -> TextTokensPrompt:
|
||||
async_tokenizer = self._get_async_tokenizer(tokenizer)
|
||||
@ -919,7 +904,7 @@ class OpenAIServing:
|
||||
self,
|
||||
request: AnyRequest,
|
||||
prompt_ids: list[int],
|
||||
tokenizer: AnyTokenizer | None,
|
||||
tokenizer: TokenizerLike | None,
|
||||
) -> TextTokensPrompt:
|
||||
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
|
||||
|
||||
@ -1015,7 +1000,7 @@ class OpenAIServing:
|
||||
async def _tokenize_prompt_input_async(
|
||||
self,
|
||||
request: AnyRequest,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
prompt_input: str | list[int],
|
||||
add_special_tokens: bool = True,
|
||||
) -> TextTokensPrompt:
|
||||
@ -1034,7 +1019,7 @@ class OpenAIServing:
|
||||
async def _tokenize_prompt_inputs_async(
|
||||
self,
|
||||
request: AnyRequest,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
prompt_inputs: Iterable[str | list[int]],
|
||||
add_special_tokens: bool = True,
|
||||
) -> AsyncGenerator[TextTokensPrompt, None]:
|
||||
@ -1079,7 +1064,7 @@ class OpenAIServing:
|
||||
async def _preprocess_chat(
|
||||
self,
|
||||
request: ChatLikeRequest | ResponsesRequest,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike | None,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
chat_template: str | None,
|
||||
chat_template_content_format: ChatTemplateContentFormatOption,
|
||||
@ -1088,13 +1073,18 @@ class OpenAIServing:
|
||||
tool_dicts: list[dict[str, Any]] | None = None,
|
||||
documents: list[dict[str, str]] | None = None,
|
||||
chat_template_kwargs: dict[str, Any] | None = None,
|
||||
tool_parser: Callable[[AnyTokenizer], ToolParser] | None = None,
|
||||
tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
|
||||
add_special_tokens: bool = False,
|
||||
) -> tuple[
|
||||
list[ConversationMessage],
|
||||
Sequence[RequestPrompt],
|
||||
list[EngineTokensPrompt],
|
||||
]:
|
||||
if tokenizer is None:
|
||||
raise ValueError(
|
||||
"Unable to get tokenizer because `skip_tokenizer_init=True`"
|
||||
)
|
||||
|
||||
model_config = self.model_config
|
||||
|
||||
resolved_content_format = resolve_chat_template_content_format(
|
||||
@ -1370,9 +1360,9 @@ class OpenAIServing:
|
||||
@staticmethod
|
||||
def _parse_tool_calls_from_content(
|
||||
request: ResponsesRequest | ChatCompletionRequest,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
enable_auto_tools: bool,
|
||||
tool_parser_cls: Callable[[AnyTokenizer], ToolParser] | None,
|
||||
tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
|
||||
content: str | None = None,
|
||||
) -> tuple[list[FunctionCall] | None, str | None]:
|
||||
function_calls = list[FunctionCall]()
|
||||
@ -1442,7 +1432,7 @@ class OpenAIServing:
|
||||
def _get_decoded_token(
|
||||
logprob: Logprob,
|
||||
token_id: int,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike | None,
|
||||
return_as_token_id: bool = False,
|
||||
) -> str:
|
||||
if return_as_token_id:
|
||||
@ -1450,6 +1440,12 @@ class OpenAIServing:
|
||||
|
||||
if logprob.decoded_token is not None:
|
||||
return logprob.decoded_token
|
||||
|
||||
if tokenizer is None:
|
||||
raise ValueError(
|
||||
"Unable to get tokenizer because `skip_tokenizer_init=True`"
|
||||
)
|
||||
|
||||
return tokenizer.decode(token_id)
|
||||
|
||||
def _is_model_supported(self, model_name: str | None) -> bool:
|
||||
|
||||
@ -105,7 +105,7 @@ from vllm.logprobs import Logprob as SampleLogprob
|
||||
from vllm.logprobs import SampleLogprobs
|
||||
from vllm.outputs import CompletionOutput
|
||||
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -492,7 +492,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
self,
|
||||
request: ResponsesRequest,
|
||||
prev_response: ResponsesResponse | None,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
):
|
||||
if request.tools is None or (
|
||||
request.tool_choice == "none" and self.exclude_tools_when_tool_choice_none
|
||||
@ -563,7 +563,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
result_generator: AsyncIterator[ConversationContext],
|
||||
context: ConversationContext,
|
||||
model_name: str,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
created_time: int | None = None,
|
||||
) -> ErrorResponse | ResponsesResponse:
|
||||
@ -675,7 +675,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
self,
|
||||
logprobs: dict[int, SampleLogprob],
|
||||
top_logprobs: int,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
) -> list[LogprobTopLogprob]:
|
||||
"""Returns the top-k logprobs from the logprobs dictionary."""
|
||||
out = []
|
||||
@ -700,7 +700,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
self,
|
||||
token_ids: Sequence[int],
|
||||
logprobs: SampleLogprobs | None,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
top_logprobs: int | None = None,
|
||||
) -> list[Logprob]:
|
||||
assert logprobs is not None, "logprobs must be provided"
|
||||
@ -736,7 +736,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
self,
|
||||
token_ids: Sequence[int],
|
||||
logprobs: SampleLogprobs | None,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
top_logprobs: int | None = None,
|
||||
) -> list[response_text_delta_event.Logprob]:
|
||||
lgs = self._create_response_logprobs(
|
||||
@ -763,7 +763,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
self,
|
||||
request: ResponsesRequest,
|
||||
final_output: CompletionOutput,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
) -> list[ResponseOutputItem]:
|
||||
if self.reasoning_parser:
|
||||
try:
|
||||
@ -1135,7 +1135,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
result_generator: AsyncIterator[ConversationContext | None],
|
||||
context: ConversationContext,
|
||||
model_name: str,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
created_time: int,
|
||||
_increment_sequence_number_and_return: Callable[
|
||||
@ -1438,7 +1438,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
result_generator: AsyncIterator[ConversationContext | None],
|
||||
context: ConversationContext,
|
||||
model_name: str,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
created_time: int,
|
||||
_increment_sequence_number_and_return: Callable[
|
||||
@ -1891,7 +1891,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
result_generator: AsyncIterator[ConversationContext | None],
|
||||
context: ConversationContext,
|
||||
model_name: str,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
created_time: int | None = None,
|
||||
) -> AsyncGenerator[StreamingResponsesResponse, None]:
|
||||
|
||||
@ -36,7 +36,7 @@ from vllm.inputs.data import TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.tokenizers import MistralTokenizer, TokenizerLike
|
||||
from vllm.utils.async_utils import make_async, merge_async_iterators
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -60,7 +60,7 @@ class ServingScores(OpenAIServing):
|
||||
|
||||
async def _embedding_score(
|
||||
self,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
texts_1: list[str],
|
||||
texts_2: list[str],
|
||||
request: RerankRequest | ScoreRequest,
|
||||
@ -153,7 +153,7 @@ class ServingScores(OpenAIServing):
|
||||
def _preprocess_score(
|
||||
self,
|
||||
request: RerankRequest | ScoreRequest,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
tokenization_kwargs: dict[str, Any],
|
||||
data_1: str | ScoreContentPartParam,
|
||||
data_2: str | ScoreContentPartParam,
|
||||
@ -175,7 +175,7 @@ class ServingScores(OpenAIServing):
|
||||
|
||||
async def _cross_encoding_score(
|
||||
self,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
data_1: list[str] | list[ScoreContentPartParam],
|
||||
data_2: list[str] | list[ScoreContentPartParam],
|
||||
request: RerankRequest | ScoreRequest,
|
||||
|
||||
@ -22,7 +22,7 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.renderer import RenderConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -170,7 +170,7 @@ class OpenAIServingTokenization(OpenAIServing):
|
||||
|
||||
@dataclass
|
||||
class TokenizerInfo:
|
||||
tokenizer: AnyTokenizer
|
||||
tokenizer: TokenizerLike
|
||||
chat_template: str | None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
|
||||
@ -22,7 +22,7 @@ from vllm.logger import init_logger
|
||||
from vllm.sampling_params import (
|
||||
StructuredOutputsParams,
|
||||
)
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils.collection_utils import is_list_of
|
||||
from vllm.utils.import_utils import import_from_path
|
||||
|
||||
@ -36,7 +36,7 @@ class ToolParser:
|
||||
derived classes.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
def __init__(self, tokenizer: TokenizerLike):
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
# the index of the tool call that is currently being parsed
|
||||
self.current_tool_id: int = -1
|
||||
|
||||
@ -19,13 +19,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DeepSeekV31ToolParser(ToolParser):
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
def __init__(self, tokenizer: TokenizerLike):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
self.current_tool_name_sent: bool = False
|
||||
|
||||
@ -19,13 +19,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DeepSeekV3ToolParser(ToolParser):
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
def __init__(self, tokenizer: TokenizerLike):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
self.current_tool_name_sent: bool = False
|
||||
|
||||
@ -19,13 +19,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Ernie45ToolParser(ToolParser):
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
def __init__(self, tokenizer: TokenizerLike):
|
||||
"""
|
||||
Ernie thinking model format:
|
||||
abc\n</think>\n\n\n<tool_call>\ndef\n</tool_call>\n
|
||||
|
||||
@ -22,13 +22,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Glm4MoeModelToolParser(ToolParser):
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
def __init__(self, tokenizer: TokenizerLike):
|
||||
super().__init__(tokenizer)
|
||||
self.current_tool_name_sent = False
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
|
||||
@ -29,7 +29,7 @@ from vllm.entrypoints.openai.tool_parsers.utils import (
|
||||
partial_json_loads,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -44,7 +44,7 @@ class Granite20bFCToolParser(ToolParser):
|
||||
are all set
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
def __init__(self, tokenizer: TokenizerLike):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
self.bot_token = "<function_call>"
|
||||
|
||||
@ -27,7 +27,7 @@ from vllm.entrypoints.openai.tool_parsers.utils import (
|
||||
partial_json_loads,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -42,7 +42,7 @@ class GraniteToolParser(ToolParser):
|
||||
are all set
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
def __init__(self, tokenizer: TokenizerLike):
|
||||
super().__init__(tokenizer)
|
||||
# for granite 3.0, the token `<|tool_call|>`
|
||||
self.bot_token = "<|tool_call|>"
|
||||
|
||||
@ -22,18 +22,18 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.tokenizers import MistralTokenizer, TokenizerLike
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Hermes2ProToolParser(ToolParser):
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
def __init__(self, tokenizer: TokenizerLike):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
if isinstance(self.model_tokenizer, MistralTokenizer):
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
logger.error("Detected Mistral tokenizer when using a Hermes model")
|
||||
self.model_tokenizer = self.model_tokenizer.tokenizer
|
||||
self.model_tokenizer = tokenizer.tokenizer
|
||||
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
|
||||
@ -22,14 +22,14 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
)
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import consume_space
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class HunyuanA13BToolParser(ToolParser):
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
def __init__(self, tokenizer: TokenizerLike):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
# Initialize state for streaming mode
|
||||
|
||||
@ -22,13 +22,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
)
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Internlm2ToolParser(ToolParser):
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
def __init__(self, tokenizer: TokenizerLike):
|
||||
super().__init__(tokenizer)
|
||||
self.position = 0
|
||||
|
||||
|
||||
@ -21,14 +21,13 @@ from vllm.entrypoints.openai.protocol import (
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.transformers_utils.tokenizers import MistralTokenizer
|
||||
from vllm.tokenizers import MistralTokenizer, TokenizerLike
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class JambaToolParser(ToolParser):
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
def __init__(self, tokenizer: TokenizerLike):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
if isinstance(self.model_tokenizer, MistralTokenizer):
|
||||
|
||||
@ -19,13 +19,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class KimiK2ToolParser(ToolParser):
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
def __init__(self, tokenizer: TokenizerLike):
|
||||
super().__init__(tokenizer)
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
|
||||
@ -4,11 +4,11 @@
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
|
||||
class LongcatFlashToolParser(Hermes2ProToolParser):
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
def __init__(self, tokenizer: TokenizerLike):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
self.tool_call_start_token: str = "<longcat_tool_call>"
|
||||
|
||||
@ -21,13 +21,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MinimaxM2ToolParser(ToolParser):
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
def __init__(self, tokenizer: TokenizerLike):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
|
||||
@ -22,13 +22,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
)
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MinimaxToolParser(ToolParser):
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
def __init__(self, tokenizer: TokenizerLike):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
# Initialize streaming state for tracking tool call progress
|
||||
|
||||
@ -25,7 +25,7 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
)
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.tokenizers import MistralTokenizer, TokenizerLike
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -46,7 +46,7 @@ class MistralToolCall(ToolCall):
|
||||
return id.isalnum() and len(id) == 9
|
||||
|
||||
|
||||
def _is_fn_name_regex_support(model_tokenizer: AnyTokenizer) -> bool:
|
||||
def _is_fn_name_regex_support(model_tokenizer: TokenizerLike) -> bool:
|
||||
return (
|
||||
isinstance(model_tokenizer, MistralTokenizer) and model_tokenizer.version >= 11
|
||||
)
|
||||
@ -61,7 +61,7 @@ class MistralToolParser(ToolParser):
|
||||
Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
def __init__(self, tokenizer: TokenizerLike):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
if not isinstance(self.model_tokenizer, MistralTokenizer):
|
||||
|
||||
@ -18,15 +18,15 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
from vllm.logger import init_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
else:
|
||||
AnyTokenizer = object
|
||||
TokenizerLike = object
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIToolParser(ToolParser):
|
||||
def __init__(self, tokenizer: "AnyTokenizer"):
|
||||
def __init__(self, tokenizer: "TokenizerLike"):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
def extract_tool_calls(
|
||||
|
||||
@ -22,13 +22,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Qwen3CoderToolParser(ToolParser):
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
def __init__(self, tokenizer: TokenizerLike):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
self.current_tool_name_sent: bool = False
|
||||
|
||||
@ -23,7 +23,7 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -1165,7 +1165,7 @@ class StreamingXMLToolCallParser:
|
||||
|
||||
|
||||
class Qwen3XMLToolParser(ToolParser):
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
def __init__(self, tokenizer: TokenizerLike):
|
||||
super().__init__(tokenizer)
|
||||
self.parser = StreamingXMLToolCallParser()
|
||||
|
||||
|
||||
@ -25,7 +25,7 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -34,7 +34,7 @@ class SeedOssToolParser(ToolParser):
|
||||
TOOL_CALL_START = "<seed:tool_call>"
|
||||
TOOL_CALL_END = "</seed:tool_call>"
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
def __init__(self, tokenizer: TokenizerLike):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
# --- streaming state ---
|
||||
|
||||
@ -21,7 +21,7 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -41,7 +41,7 @@ class Step3ToolParser(ToolParser):
|
||||
TOOL_SEP = "<|tool_sep|>"
|
||||
SPECIAL_TOKENS = [TOOL_CALLS_BEGIN, TOOL_CALLS_END, TOOL_CALL_BEGIN, TOOL_CALL_END]
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
def __init__(self, tokenizer: TokenizerLike):
|
||||
super().__init__(tokenizer)
|
||||
self.position = 0
|
||||
# Explicit state flags for robust streaming
|
||||
|
||||
@ -21,14 +21,14 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class xLAMToolParser(ToolParser):
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
def __init__(self, tokenizer: TokenizerLike):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
# Initialize state for streaming mode
|
||||
|
||||
@ -16,7 +16,7 @@ from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
|
||||
from vllm.inputs.data import TextPrompt as EngineTextPrompt
|
||||
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
|
||||
from vllm.inputs.parse import get_prompt_components, parse_raw_prompts
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils.async_utils import AsyncMicrobatchTokenizer
|
||||
|
||||
|
||||
@ -85,7 +85,7 @@ class BaseRenderer(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
tokenizer: AnyTokenizer | None = None,
|
||||
tokenizer: TokenizerLike | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.model_config = model_config
|
||||
@ -200,8 +200,8 @@ class CompletionRenderer(BaseRenderer):
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
tokenizer: AnyTokenizer | None = None,
|
||||
async_tokenizer_pool: dict[AnyTokenizer, AsyncMicrobatchTokenizer]
|
||||
tokenizer: TokenizerLike | None = None,
|
||||
async_tokenizer_pool: dict[TokenizerLike, AsyncMicrobatchTokenizer]
|
||||
| None = None,
|
||||
):
|
||||
super().__init__(model_config, tokenizer)
|
||||
@ -373,7 +373,7 @@ class CompletionRenderer(BaseRenderer):
|
||||
return async_tokenizer
|
||||
|
||||
tokenizer = self.tokenizer
|
||||
if self.tokenizer is None:
|
||||
if tokenizer is None:
|
||||
raise ValueError("No tokenizer available for text input processing")
|
||||
|
||||
if self.async_tokenizer_pool is None:
|
||||
|
||||
@ -19,11 +19,7 @@ from vllm.inputs import TokensPrompt
|
||||
from vllm.model_executor.models.interfaces import supports_score_template
|
||||
from vllm.multimodal.inputs import MultiModalDataDict
|
||||
from vllm.outputs import PoolingRequestOutput
|
||||
from vllm.transformers_utils.tokenizer import (
|
||||
AnyTokenizer,
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast,
|
||||
)
|
||||
from vllm.transformers_utils.tokenizer import TokenizerLike
|
||||
|
||||
ScoreContentPartParam: TypeAlias = (
|
||||
ChatCompletionContentPartImageParam | ChatCompletionContentPartImageEmbedsParam
|
||||
@ -45,7 +41,7 @@ class ScoreMultiModalParam(TypedDict, total=False):
|
||||
|
||||
|
||||
def _cosine_similarity(
|
||||
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
|
||||
tokenizer: TokenizerLike,
|
||||
embed_1: list[PoolingRequestOutput],
|
||||
embed_2: list[PoolingRequestOutput],
|
||||
) -> list[PoolingRequestOutput]:
|
||||
@ -93,7 +89,7 @@ def parse_score_data(
|
||||
data_1: str | ScoreContentPartParam,
|
||||
data_2: str | ScoreContentPartParam,
|
||||
model_config: ModelConfig,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
) -> tuple[str, str, MultiModalDataDict | None]:
|
||||
mm_tracker = MultiModalItemTracker(model_config, tokenizer)
|
||||
|
||||
@ -118,12 +114,14 @@ def _parse_score_content(
|
||||
mm_tracker: BaseMultiModalItemTracker,
|
||||
) -> _ContentPart | None:
|
||||
if isinstance(data, str):
|
||||
data = ChatCompletionContentPartTextParam(type="text", text=data)
|
||||
part = ChatCompletionContentPartTextParam(type="text", text=data)
|
||||
else:
|
||||
part = data
|
||||
|
||||
mm_parser = mm_tracker.create_parser()
|
||||
|
||||
parse_res = _parse_chat_message_content_part(
|
||||
data,
|
||||
part,
|
||||
mm_parser,
|
||||
wrap_dicts=False,
|
||||
interleave_strings=False,
|
||||
@ -181,7 +179,7 @@ def post_process_tokens(
|
||||
|
||||
def get_score_prompt(
|
||||
model_config: ModelConfig,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
tokenization_kwargs: dict[str, Any],
|
||||
data_1: str | ScoreContentPartParam,
|
||||
data_2: str | ScoreContentPartParam,
|
||||
|
||||
@ -30,7 +30,7 @@ from vllm.entrypoints.openai.protocol import (
|
||||
from vllm.entrypoints.openai.serving_models import LoRAModulePath
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.transformers_utils.tokenizers import MistralTokenizer
|
||||
from vllm.tokenizers import MistralTokenizer
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -17,7 +17,7 @@ from vllm.multimodal.inputs import (
|
||||
MultiModalUUIDDict,
|
||||
)
|
||||
from vllm.multimodal.processing import BaseMultiModalProcessor
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils.jsontree import json_iter_leaves
|
||||
from vllm.v1.metrics.stats import MultiModalCacheStats
|
||||
|
||||
@ -46,7 +46,7 @@ class InputPreprocessor:
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
tokenizer: AnyTokenizer | None,
|
||||
tokenizer: TokenizerLike | None,
|
||||
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
||||
mm_processor_cache: BaseMultiModalProcessorCache | None = None,
|
||||
) -> None:
|
||||
@ -59,7 +59,7 @@ class InputPreprocessor:
|
||||
|
||||
self.mm_cache_stats = MultiModalCacheStats() if mm_processor_cache else None
|
||||
|
||||
def get_tokenizer(self) -> AnyTokenizer:
|
||||
def get_tokenizer(self) -> TokenizerLike:
|
||||
if self.tokenizer is None:
|
||||
raise ValueError(
|
||||
"You cannot pass text prompts when `skip_tokenizer_init` is True"
|
||||
@ -228,11 +228,11 @@ class InputPreprocessor:
|
||||
|
||||
return tokenizer.encode(prompt, **tokenization_kwargs)
|
||||
|
||||
def _get_mm_tokenizer(self) -> AnyTokenizer:
|
||||
def _get_mm_tokenizer(self) -> TokenizerLike:
|
||||
# PrithviGeoSpatialMAE needs to be initialized without a tokenizer
|
||||
# while using also multi-modal input
|
||||
if not self.tokenizer:
|
||||
return cast(AnyTokenizer, object()) # Dummy
|
||||
return cast(TokenizerLike, object()) # Dummy
|
||||
|
||||
tokenizer = self.get_tokenizer()
|
||||
return tokenizer
|
||||
|
||||
@ -5,7 +5,7 @@ from typing import TypeAlias
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
LogitsProcessor: TypeAlias = (
|
||||
Callable[[list[int], torch.Tensor], torch.Tensor]
|
||||
@ -19,7 +19,7 @@ to sample from."""
|
||||
|
||||
|
||||
def get_bad_words_logits_processors(
|
||||
bad_words: list[str], tokenizer: AnyTokenizer
|
||||
bad_words: list[str], tokenizer: TokenizerLike
|
||||
) -> list[LogitsProcessor]:
|
||||
bad_words_ids: list[list[int]] = list()
|
||||
|
||||
|
||||
@ -28,7 +28,7 @@ from vllm.multimodal.processing import (
|
||||
PromptUpdate,
|
||||
PromptUpdateDetails,
|
||||
)
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
from .intern_vit import InternVisionModel
|
||||
from .internvl import (
|
||||
@ -241,7 +241,7 @@ class H2OVLProcessor(BaseInternVLProcessor):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
*,
|
||||
min_dynamic_patch: int | None = None,
|
||||
max_dynamic_patch: int | None = None,
|
||||
|
||||
@ -50,7 +50,7 @@ from vllm.multimodal.processing import (
|
||||
)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
from vllm.utils.torch_utils import set_default_torch_num_threads
|
||||
|
||||
@ -347,7 +347,7 @@ class BaseInternVLProcessor(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
*,
|
||||
min_dynamic_patch: int | None = None,
|
||||
max_dynamic_patch: int | None = None,
|
||||
@ -561,7 +561,7 @@ class InternVLProcessor(BaseInternVLProcessor):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
*,
|
||||
min_dynamic_patch: int | None = None,
|
||||
max_dynamic_patch: int | None = None,
|
||||
|
||||
@ -73,9 +73,9 @@ from vllm.multimodal.processing import (
|
||||
)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.transformers_utils.configs.radio import RadioConfig
|
||||
from vllm.transformers_utils.tokenizer import (
|
||||
AnyTokenizer,
|
||||
cached_tokenizer_from_config,
|
||||
encode_tokens,
|
||||
)
|
||||
@ -284,7 +284,7 @@ class BaseNanoNemotronVLProcessor(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
*args,
|
||||
max_num_tiles: int | None = None,
|
||||
**kwargs,
|
||||
@ -434,7 +434,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
*,
|
||||
max_num_tiles: int | None = None,
|
||||
min_dynamic_patch: int | None = None,
|
||||
@ -645,7 +645,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
tokens_per_frame: list[int],
|
||||
frames_indices: list[int],
|
||||
frame_duration_ms: int,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
img_start_token_ids: list[int],
|
||||
img_end_token_ids: list[int],
|
||||
img_context_token_ids: list[int],
|
||||
@ -670,7 +670,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
tokens_per_frame (list[int]): number of tokens per frame
|
||||
frames_indices (list[int]): frame indices
|
||||
frame_duration_ms (int): duration of each frame in milliseconds
|
||||
tokenizer (AnyTokenizer): tokenizer to use for tokenizing frame separators
|
||||
tokenizer (TokenizerLike): tokenizer to use for tokenizing frame separators
|
||||
img_start_token_ids (list[int]): pre-tokenized IMG_START tokens
|
||||
img_end_token_ids (list[int]): pre-tokenized IMG_END tokens
|
||||
img_context_token_ids (list[int]): pre-tokenized IMG_CONTEXT tokens
|
||||
|
||||
@ -34,8 +34,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.image import convert_image_mode
|
||||
from vllm.multimodal.processing import PromptUpdateDetails
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.transformers_utils.processor import cached_image_processor_from_config
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
from .interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
@ -203,7 +203,7 @@ class NemotronVLProcessor(InternVLProcessor):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
image_processor: BaseImageProcessorFast,
|
||||
*,
|
||||
min_dynamic_patch: int | None = None,
|
||||
|
||||
@ -31,7 +31,7 @@ from vllm.multimodal.processing import (
|
||||
PromptReplacement,
|
||||
PromptUpdate,
|
||||
)
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
from .qwen2_5_vl import (
|
||||
Qwen2_5_VisionTransformer as OpenCUAVisionTransformer,
|
||||
@ -79,7 +79,7 @@ class OpenCUAProcessor(Qwen2VLProcessor):
|
||||
def __init__(
|
||||
self,
|
||||
vision_config: dict,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
**kwargs,
|
||||
):
|
||||
image_processor = Qwen2VLImageProcessor(**vision_config)
|
||||
|
||||
@ -59,10 +59,8 @@ from vllm.multimodal.processing import (
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.tokenizer import (
|
||||
MistralTokenizer,
|
||||
cached_tokenizer_from_config,
|
||||
)
|
||||
from vllm.tokenizers import MistralTokenizer
|
||||
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
|
||||
@ -91,7 +91,7 @@ from vllm.multimodal.processing import (
|
||||
)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import (
|
||||
@ -1533,7 +1533,7 @@ class Tarsier2Processor(Qwen2VLProcessor):
|
||||
def __init__(
|
||||
self,
|
||||
vision_config: dict,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
**kwargs,
|
||||
):
|
||||
self.image_processor = Tarsier2ImageProcessor(**vision_config)
|
||||
|
||||
@ -47,7 +47,7 @@ from vllm.multimodal.processing import (
|
||||
)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
@ -282,7 +282,7 @@ class SkyworkR1VProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
*,
|
||||
min_dynamic_patch: int | None = None,
|
||||
max_dynamic_patch: int | None = None,
|
||||
|
||||
@ -43,8 +43,8 @@ from vllm.multimodal.processing import (
|
||||
)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.transformers_utils.configs import Step3VisionEncoderConfig
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
@ -321,7 +321,7 @@ class Step3VLProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
|
||||
@ -51,10 +51,8 @@ from vllm.multimodal.processing import (
|
||||
)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.tokenizer import (
|
||||
MistralTokenizer,
|
||||
cached_tokenizer_from_config,
|
||||
)
|
||||
from vllm.tokenizers import MistralTokenizer
|
||||
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription
|
||||
from .utils import init_vllm_registered_model, maybe_prefix
|
||||
|
||||
@ -23,8 +23,9 @@ import torch
|
||||
from typing_extensions import TypeVar, assert_never
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.transformers_utils.processor import cached_processor_from_config
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, decode_tokens, encode_tokens
|
||||
from vllm.transformers_utils.tokenizer import decode_tokens, encode_tokens
|
||||
from vllm.utils.collection_utils import flatten_2d_lists, full_groupby
|
||||
from vllm.utils.func_utils import get_allowed_kwarg_only_overrides
|
||||
from vllm.utils.jsontree import JSONTree, json_map_leaves
|
||||
@ -76,7 +77,7 @@ PromptSeq: TypeAlias = str | list[int]
|
||||
|
||||
@lru_cache(maxsize=2048)
|
||||
def _cached_encode(
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
text: str,
|
||||
*,
|
||||
add_special_tokens: bool | None = None,
|
||||
@ -86,7 +87,7 @@ def _cached_encode(
|
||||
|
||||
@lru_cache(maxsize=2048)
|
||||
def _cached_decode(
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
token_ids: tuple[int, ...],
|
||||
*,
|
||||
skip_special_tokens: bool | None = None,
|
||||
@ -96,14 +97,14 @@ def _cached_decode(
|
||||
)
|
||||
|
||||
|
||||
def _seq2text(tokenizer: AnyTokenizer, seq: PromptSeq) -> str:
|
||||
def _seq2text(tokenizer: TokenizerLike, seq: PromptSeq) -> str:
|
||||
if isinstance(seq, str):
|
||||
return seq
|
||||
|
||||
return _cached_decode(tokenizer, tuple(seq))
|
||||
|
||||
|
||||
def _seq2tokens(tokenizer: AnyTokenizer, seq: PromptSeq) -> list[int]:
|
||||
def _seq2tokens(tokenizer: TokenizerLike, seq: PromptSeq) -> list[int]:
|
||||
if isinstance(seq, str):
|
||||
return _cached_encode(tokenizer, seq, add_special_tokens=False)
|
||||
|
||||
@ -113,7 +114,7 @@ def _seq2tokens(tokenizer: AnyTokenizer, seq: PromptSeq) -> list[int]:
|
||||
class _GetMatchIndex(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
prompt: PromptSeq,
|
||||
start_idx: int = 0,
|
||||
) -> int | None: ...
|
||||
@ -143,7 +144,7 @@ class PromptIndexTargets:
|
||||
"""
|
||||
|
||||
def get_match_index(
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
prompt: PromptSeq,
|
||||
start_idx: int = 0,
|
||||
) -> int | None:
|
||||
@ -199,7 +200,7 @@ class PromptUpdateDetails(Generic[_S]):
|
||||
full: _S
|
||||
"""The full content."""
|
||||
|
||||
is_embed: Callable[[AnyTokenizer, PromptSeq], torch.Tensor] | None = None
|
||||
is_embed: Callable[[TokenizerLike, PromptSeq], torch.Tensor] | None = None
|
||||
"""
|
||||
Given [`full`][vllm.multimodal.processing.PromptUpdateDetails.full],
|
||||
return a boolean mask of shape `(len(full),)` indicating which positions
|
||||
@ -220,7 +221,7 @@ class PromptUpdateDetails(Generic[_S]):
|
||||
seq: _S,
|
||||
embed_text: str,
|
||||
) -> "PromptUpdateDetails[_S]":
|
||||
def is_embed(tokenizer: AnyTokenizer, full: PromptSeq) -> torch.Tensor:
|
||||
def is_embed(tokenizer: TokenizerLike, full: PromptSeq) -> torch.Tensor:
|
||||
embed_token_ids = encode_tokens(tokenizer, embed_text)
|
||||
token_ids = _seq2tokens(tokenizer, full)
|
||||
|
||||
@ -236,7 +237,7 @@ class PromptUpdateDetails(Generic[_S]):
|
||||
seq: _S,
|
||||
embed_token_id: int,
|
||||
) -> "PromptUpdateDetails[_S]":
|
||||
def is_embed(tokenizer: AnyTokenizer, full: PromptSeq) -> torch.Tensor:
|
||||
def is_embed(tokenizer: TokenizerLike, full: PromptSeq) -> torch.Tensor:
|
||||
token_ids = _seq2tokens(tokenizer, full)
|
||||
|
||||
return torch.tensor(token_ids) == embed_token_id
|
||||
@ -522,7 +523,7 @@ class ResolvedPromptUpdate:
|
||||
def iter_token_matches(
|
||||
self,
|
||||
prompt: list[int],
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
*,
|
||||
start_idx: int = 0,
|
||||
) -> Generator[PromptTargetMatch]:
|
||||
@ -544,7 +545,7 @@ class ResolvedPromptUpdate:
|
||||
def iter_text_matches(
|
||||
self,
|
||||
prompt: str,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
*,
|
||||
start_idx: int = 0,
|
||||
) -> Generator[PromptTargetMatch]:
|
||||
@ -566,7 +567,7 @@ class ResolvedPromptUpdate:
|
||||
def iter_matches(
|
||||
self,
|
||||
prompt: list[int] | str,
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
*,
|
||||
start_idx: int = 0,
|
||||
) -> Generator[PromptTargetMatch]:
|
||||
@ -675,7 +676,7 @@ _MatchToApply = tuple[tuple[str, int], tuple[PromptTargetMatch, int]]
|
||||
def _find_matches(
|
||||
prompt: _S,
|
||||
mm_prompt_updates: "MultiModalPromptUpdates",
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
*,
|
||||
prev_end_idx: int = 0,
|
||||
current_result: "MultiModalPromptUpdatesApplyResult",
|
||||
@ -740,7 +741,7 @@ def _all_items_found(
|
||||
def _apply_matches(
|
||||
prompt: _S,
|
||||
mm_prompt_updates: "MultiModalPromptUpdates",
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
) -> tuple[list[_S], "MultiModalPromptUpdatesApplyResult"]:
|
||||
mm_item_counts = {m: len(items) for m, items in mm_prompt_updates.items()}
|
||||
|
||||
@ -806,7 +807,7 @@ def _apply_matches(
|
||||
def apply_token_matches(
|
||||
prompt: list[int],
|
||||
mm_prompt_updates: "MultiModalPromptUpdates",
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
) -> tuple[list[int], "MultiModalPromptUpdatesApplyResult"]:
|
||||
"""
|
||||
Apply the updates in `mm_prompt_updates` to `prompt`.
|
||||
@ -823,7 +824,7 @@ def apply_token_matches(
|
||||
def apply_text_matches(
|
||||
prompt: str,
|
||||
mm_prompt_updates: "MultiModalPromptUpdates",
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
) -> tuple[str, "MultiModalPromptUpdatesApplyResult"]:
|
||||
"""
|
||||
Apply the updates in `mm_prompt_updates` to `prompt`.
|
||||
@ -840,7 +841,7 @@ def apply_text_matches(
|
||||
def _iter_placeholders(
|
||||
prompt: list[int],
|
||||
mm_prompt_updates: "MultiModalPromptUpdates",
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
) -> Iterable[PlaceholderFeaturesInfo]:
|
||||
"""
|
||||
Yield each set of placeholder tokens found in `prompt`.
|
||||
@ -909,7 +910,7 @@ def _iter_placeholders(
|
||||
def find_mm_placeholders(
|
||||
prompt: list[int],
|
||||
mm_prompt_updates: "MultiModalPromptUpdates",
|
||||
tokenizer: AnyTokenizer,
|
||||
tokenizer: TokenizerLike,
|
||||
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
|
||||
it = _iter_placeholders(prompt, mm_prompt_updates, tokenizer)
|
||||
return dict(full_groupby_modality(it))
|
||||
@ -930,7 +931,7 @@ class InputProcessingContext:
|
||||
model_config: ModelConfig
|
||||
"""The configuration of the model."""
|
||||
|
||||
tokenizer: AnyTokenizer
|
||||
tokenizer: TokenizerLike
|
||||
"""The tokenizer used to tokenize the inputs."""
|
||||
|
||||
@overload
|
||||
@ -1146,7 +1147,7 @@ class BaseProcessingInfo:
|
||||
def model_id(self) -> str:
|
||||
return self.ctx.model_config.model
|
||||
|
||||
def get_tokenizer(self) -> AnyTokenizer:
|
||||
def get_tokenizer(self) -> TokenizerLike:
|
||||
return self.ctx.tokenizer
|
||||
|
||||
def get_hf_config(self) -> PretrainedConfig:
|
||||
|
||||
@ -6,7 +6,8 @@ from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast
|
||||
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, cached_tokenizer_from_config
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
||||
|
||||
from .cache import BaseMultiModalProcessorCache
|
||||
from .processing import (
|
||||
@ -231,17 +232,20 @@ class MultiModalRegistry:
|
||||
def _create_processing_ctx(
|
||||
self,
|
||||
model_config: "ModelConfig",
|
||||
tokenizer: AnyTokenizer | None = None,
|
||||
tokenizer: TokenizerLike | None = None,
|
||||
) -> InputProcessingContext:
|
||||
if tokenizer is None and not model_config.skip_tokenizer_init:
|
||||
if model_config.skip_tokenizer_init:
|
||||
tokenizer = cast(TokenizerLike, object())
|
||||
elif tokenizer is None:
|
||||
tokenizer = cached_tokenizer_from_config(model_config)
|
||||
|
||||
return InputProcessingContext(model_config, tokenizer)
|
||||
|
||||
def _create_processing_info(
|
||||
self,
|
||||
model_config: "ModelConfig",
|
||||
*,
|
||||
tokenizer: AnyTokenizer | None = None,
|
||||
tokenizer: TokenizerLike | None = None,
|
||||
) -> BaseProcessingInfo:
|
||||
model_cls = self._get_model_cls(model_config)
|
||||
factories = model_cls._processor_factory
|
||||
@ -252,7 +256,7 @@ class MultiModalRegistry:
|
||||
self,
|
||||
model_config: "ModelConfig",
|
||||
*,
|
||||
tokenizer: AnyTokenizer | None = None,
|
||||
tokenizer: TokenizerLike | None = None,
|
||||
cache: BaseMultiModalProcessorCache | None = None,
|
||||
) -> BaseMultiModalProcessor[BaseProcessingInfo]:
|
||||
"""
|
||||
|
||||
@ -19,12 +19,12 @@ if TYPE_CHECKING:
|
||||
DeltaMessage,
|
||||
ResponsesRequest,
|
||||
)
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
else:
|
||||
ChatCompletionRequest = Any
|
||||
DeltaMessage = Any
|
||||
ResponsesRequest = Any
|
||||
AnyTokenizer = Any
|
||||
TokenizerLike = Any
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -37,7 +37,7 @@ class ReasoningParser:
|
||||
It is used to extract reasoning content from the model output.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer, *args, **kwargs):
|
||||
def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
|
||||
self.model_tokenizer = tokenizer
|
||||
|
||||
@cached_property
|
||||
|
||||
@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
from vllm.entrypoints.openai.protocol import DeltaMessage
|
||||
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
@ -43,7 +43,7 @@ class BaseThinkingReasoningParser(ReasoningParser):
|
||||
"""The token that ends reasoning content."""
|
||||
raise NotImplementedError
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer, *args, **kwargs):
|
||||
def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
|
||||
super().__init__(tokenizer, *args, **kwargs)
|
||||
|
||||
if not self.model_tokenizer:
|
||||
|
||||
@ -11,7 +11,7 @@ from vllm.entrypoints.openai.protocol import (
|
||||
from vllm.logger import init_logger
|
||||
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
|
||||
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -37,7 +37,7 @@ class MiniMaxM2AppendThinkReasoningParser(ReasoningParser):
|
||||
Reasoning parser for MiniMax M2 model.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer, *args, **kwargs):
|
||||
def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
|
||||
super().__init__(tokenizer, *args, **kwargs)
|
||||
self.end_token_id = self.vocab.get("</think>")
|
||||
|
||||
|
||||
@ -6,7 +6,7 @@ from functools import cached_property
|
||||
from vllm.logger import init_logger
|
||||
from vllm.reasoning import ReasoningParser
|
||||
from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
|
||||
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.tokenizers import MistralTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ -9,7 +9,7 @@ from typing import TYPE_CHECKING
|
||||
import regex as re
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
@ -220,7 +220,7 @@ class Olmo3ReasoningParser(ReasoningParser):
|
||||
token is missing from generation.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: "AnyTokenizer", *args, **kwargs):
|
||||
def __init__(self, tokenizer: "TokenizerLike", *args, **kwargs):
|
||||
super().__init__(tokenizer, *args, **kwargs)
|
||||
|
||||
self.think_start = r"<think>"
|
||||
|
||||
@ -13,7 +13,7 @@ from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logits_process import LogitsProcessor
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.v1.serial_utils import PydanticMsgspecMixin
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -477,7 +477,7 @@ class SamplingParams(
|
||||
eos_ids.update(self.stop_token_ids)
|
||||
self.stop_token_ids = list(eos_ids)
|
||||
|
||||
def update_from_tokenizer(self, tokenizer: AnyTokenizer) -> None:
|
||||
def update_from_tokenizer(self, tokenizer: TokenizerLike) -> None:
|
||||
if not self.bad_words:
|
||||
return
|
||||
self._bad_words_token_ids = []
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user