mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-23 20:35:32 +08:00
78 lines
2.1 KiB
Python
78 lines
2.1 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
from vllm.tokenizers import TokenizerLike
|
|
from vllm.tokenizers.registry import (
|
|
TokenizerRegistry,
|
|
get_tokenizer,
|
|
resolve_tokenizer_args,
|
|
)
|
|
|
|
|
|
class TestTokenizer(TokenizerLike):
    """Minimal ``TokenizerLike`` stub used to exercise the tokenizer registry.

    It only remembers the path/repo id it was built from and reports fixed
    special-token ids (bos=0, eos=1, pad=2), so a test can confirm that the
    registry handed back this exact implementation.
    """

    # The "Test" name prefix would otherwise make pytest try to collect this
    # helper as a test class (and warn, because it defines ``__init__``).
    __test__ = False

    @classmethod
    def from_pretrained(
        cls,
        path_or_repo_id: str | Path,
        *args,
        trust_remote_code: bool = False,
        revision: str | None = None,
        download_dir: str | None = None,
        **kwargs,
    ) -> "TestTokenizer":
        """Create a stub for ``path_or_repo_id``.

        The remaining loading options are accepted only so the signature
        matches what the registry passes; they are ignored.
        """
        # ``cls`` (rather than the hard-coded class name) keeps this
        # constructor correct for subclasses.
        return cls(path_or_repo_id)

    def __init__(self, path_or_repo_id: str | Path) -> None:
        super().__init__()

        # Recorded so tests can check which path the registry forwarded.
        self.path_or_repo_id = path_or_repo_id

    @property
    def bos_token_id(self) -> int:
        """Fixed beginning-of-sequence id of the stub."""
        return 0

    @property
    def eos_token_id(self) -> int:
        """Fixed end-of-sequence id of the stub."""
        return 1

    @property
    def pad_token_id(self) -> int:
        """Fixed padding id of the stub."""
        return 2

    @property
    def is_fast(self) -> bool:
        """The stub always reports itself as a fast tokenizer."""
        return True
|
|
|
|
|
|
@pytest.mark.parametrize("runner_type", ["generate", "pooling"])
def test_resolve_tokenizer_args_idempotent(runner_type):
    """Re-resolving already-resolved tokenizer arguments is a no-op."""
    mode, name, positional, keyword = resolve_tokenizer_args(
        "facebook/opt-125m",
        runner_type=runner_type,
    )
    first_pass = (mode, name, positional, keyword)

    # Feed the resolved name/args/kwargs straight back into the resolver.
    second_pass = resolve_tokenizer_args(name, *positional, **keyword)

    assert first_pass == second_pass
|
|
|
|
|
|
def test_customized_tokenizer():
    """A custom tokenizer mode yields ``TestTokenizer`` both when loaded
    directly from the registry and when requested through ``get_tokenizer``."""
    TokenizerRegistry.register("test_tokenizer", __name__, TestTokenizer.__name__)

    def _check(tok) -> None:
        # The stub must record its source path and expose its fixed ids.
        assert isinstance(tok, TestTokenizer)
        assert tok.path_or_repo_id == "abc"
        assert tok.bos_token_id == 0
        assert tok.eos_token_id == 1
        assert tok.pad_token_id == 2

    # Direct lookup in the registry by mode name.
    _check(TokenizerRegistry.load_tokenizer("test_tokenizer", "abc"))

    # The same mode routed through the public factory function.
    _check(get_tokenizer("abc", tokenizer_mode="test_tokenizer"))
|