feat: add a new endpoint get_tokenizer_info to provide tokenizer/chat-template information (#20575)

Signed-off-by: m-misiura <mmisiura@redhat.com>
Mac Misiura 2025-07-16 14:52:14 +01:00 committed by GitHub
parent 1c3198b6c4
commit 18bdcf4113
5 changed files with 182 additions and 3 deletions
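
For context, the endpoint added here is opt-in and can be exercised roughly as follows once a server is up (a minimal sketch, not part of this diff; assumes a local server started with --enable-tokenizer-info-endpoint on the default port 8000):

import requests

resp = requests.get("http://localhost:8000/tokenizer_info")
resp.raise_for_status()
info = resp.json()

# tokenizer_class is the only field the response model requires; the rest
# mirrors the model's tokenizer_config.json and varies per tokenizer.
print(info["tokenizer_class"])
print(info.get("chat_template", "<no chat template configured>"))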

View File

@@ -32,6 +32,7 @@ def server(zephyr_lora_added_tokens_files: str):  # noqa: F811
        f"zephyr-lora2={zephyr_lora_added_tokens_files}",
        "--max-lora-rank",
        "64",
        "--enable-tokenizer-info-endpoint",
    ]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@@ -283,3 +284,106 @@ async def test_detokenize(
    response.raise_for_status()

    assert response.json() == {"prompt": prompt}


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name,tokenizer_name",
    [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
    indirect=["tokenizer_name"],
)
async def test_tokenizer_info_basic(
    server: RemoteOpenAIServer,
    model_name: str,
    tokenizer_name: str,
):
    """Test basic tokenizer info endpoint functionality."""
    response = requests.get(server.url_for("tokenizer_info"))
    response.raise_for_status()
    result = response.json()
    assert "tokenizer_class" in result
    assert isinstance(result["tokenizer_class"], str)
    assert result["tokenizer_class"]


@pytest.mark.asyncio
async def test_tokenizer_info_schema(server: RemoteOpenAIServer):
    """Test that the response matches expected schema types."""
    response = requests.get(server.url_for("tokenizer_info"))
    response.raise_for_status()
    result = response.json()
    field_types = {
        "add_bos_token": bool,
        "add_prefix_space": bool,
        "clean_up_tokenization_spaces": bool,
        "split_special_tokens": bool,
        "bos_token": str,
        "eos_token": str,
        "pad_token": str,
        "unk_token": str,
        "chat_template": str,
        "errors": str,
        "model_max_length": int,
        "additional_special_tokens": list,
        "added_tokens_decoder": dict,
    }
    for field, expected_type in field_types.items():
        if field in result and result[field] is not None:
            assert isinstance(
                result[field],
                expected_type), (f"{field} should be {expected_type.__name__}")


@pytest.mark.asyncio
async def test_tokenizer_info_added_tokens_structure(
    server: RemoteOpenAIServer, ):
    """Test added_tokens_decoder structure if present."""
    response = requests.get(server.url_for("tokenizer_info"))
    response.raise_for_status()
    result = response.json()
    added_tokens = result.get("added_tokens_decoder")
    if added_tokens:
        for token_id, token_info in added_tokens.items():
            assert isinstance(token_id, str), "Token IDs should be strings"
            assert isinstance(token_info, dict), "Token info should be a dict"
            assert "content" in token_info, "Token info should have content"
            assert "special" in token_info, (
                "Token info should have special flag")
            assert isinstance(token_info["special"],
                              bool), ("Special flag should be boolean")


@pytest.mark.asyncio
async def test_tokenizer_info_consistency_with_tokenize(
    server: RemoteOpenAIServer, ):
    """Test that tokenizer info is consistent with tokenization endpoint."""
    info_response = requests.get(server.url_for("tokenizer_info"))
    info_response.raise_for_status()
    info = info_response.json()
    tokenize_response = requests.post(
        server.url_for("tokenize"),
        json={
            "model": MODEL_NAME,
            "prompt": "Hello world!"
        },
    )
    tokenize_response.raise_for_status()
    tokenize_result = tokenize_response.json()
    info_max_len = info.get("model_max_length")
    tokenize_max_len = tokenize_result.get("max_model_len")
    if info_max_len and tokenize_max_len:
        assert info_max_len >= tokenize_max_len, (
            "Info max length should be >= tokenize max length")


@pytest.mark.asyncio
async def test_tokenizer_info_chat_template(server: RemoteOpenAIServer):
    """Test chat template is properly included."""
    response = requests.get(server.url_for("tokenizer_info"))
    response.raise_for_status()
    result = response.json()
    chat_template = result.get("chat_template")
    if chat_template:
        assert isinstance(chat_template,
                          str), ("Chat template should be a string")
        assert chat_template.strip(), "Chat template should not be empty"
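
Taken together, these tests pin down the response shape. A hypothetical response.json() for a Llama-family tokenizer might look like the following (illustrative values only, not captured from a real run):

{
    "tokenizer_class": "LlamaTokenizerFast",
    "add_bos_token": True,
    "bos_token": "<s>",
    "eos_token": "</s>",
    "model_max_length": 4096,
    "chat_template": "{% for message in messages %}...{% endfor %}",
    "added_tokens_decoder": {
        "0": {"content": "<unk>", "special": True},
    },
}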

View File

@@ -522,6 +522,19 @@ async def detokenize(request: DetokenizeRequest, raw_request: Request):
    assert_never(generator)


def maybe_register_tokenizer_info_endpoint(args):
    """Conditionally register the tokenizer info endpoint if enabled."""
    if getattr(args, 'enable_tokenizer_info_endpoint', False):

        @router.get("/tokenizer_info")
        async def get_tokenizer_info(raw_request: Request):
            """Get comprehensive tokenizer information."""
            result = await tokenization(raw_request).get_tokenizer_info()
            return JSONResponse(content=result.model_dump(),
                                status_code=result.code if isinstance(
                                    result, ErrorResponse) else 200)


@router.get("/v1/models")
async def show_available_models(raw_request: Request):
    handler = models(raw_request)
@@ -1692,6 +1705,7 @@ async def run_server_worker(listen_address,
        uvicorn_kwargs['log_config'] = log_config

    async with build_async_engine_client(args, client_config) as engine_client:
        maybe_register_tokenizer_info_endpoint(args)
        app = build_app(args)

        vllm_config = await engine_client.get_vllm_config()
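
Note the ordering: maybe_register_tokenizer_info_endpoint(args) runs before build_app(args) because FastAPI copies a router's routes when the router is included, so routes added afterwards are not picked up. A minimal, self-contained sketch of the same conditional-registration pattern (plain FastAPI; names hypothetical, not vLLM's actual code):

from fastapi import APIRouter, FastAPI

router = APIRouter()

def maybe_register(enabled: bool) -> None:
    # The route only exists when the flag is set; otherwise the path 404s.
    if enabled:

        @router.get("/tokenizer_info")
        async def tokenizer_info() -> dict:
            return {"tokenizer_class": "PreTrainedTokenizerFast"}  # stub

app = FastAPI()
maybe_register(enabled=True)  # must happen before include_router below
app.include_router(router)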

View File

@@ -182,6 +182,9 @@ schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
    """If set to True, enable tracking server_load_metrics in the app state."""
    enable_force_include_usage: bool = False
    """If set to True, include usage on every request."""
    enable_tokenizer_info_endpoint: bool = False
    """Enable the /tokenizer_info endpoint. May expose chat
    templates and other tokenizer configuration."""

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
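
FlexibleArgumentParser derives CLI flags from these annotated fields, so the boolean above surfaces as the --enable-tokenizer-info-endpoint flag used in the test fixture. In plain argparse terms, the effect is roughly this (a hedged equivalent for illustration, not vLLM's actual parser code):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--enable-tokenizer-info-endpoint",
                    action="store_true",
                    default=False,
                    help="Enable the /tokenizer_info endpoint.")

args = parser.parse_args(["--enable-tokenizer-info-endpoint"])
assert args.enable_tokenizer_info_endpoint is True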

View File

@@ -1953,6 +1953,16 @@ class DetokenizeResponse(OpenAIBaseModel):
    prompt: str


class TokenizerInfoResponse(OpenAIBaseModel):
    """
    Response containing tokenizer configuration,
    equivalent to tokenizer_config.json.
    """

    model_config = ConfigDict(extra="allow")
    tokenizer_class: str


class LoadLoRAAdapterRequest(BaseModel):
    lora_name: str
    lora_path: str
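
The extra="allow" config is what lets arbitrary tokenizer_config.json keys pass through while only tokenizer_class is mandatory. A quick standalone illustration with plain pydantic (a sketch; OpenAIBaseModel itself is vLLM-internal):

from pydantic import BaseModel, ConfigDict

class TokenizerInfoSketch(BaseModel):
    model_config = ConfigDict(extra="allow")
    tokenizer_class: str

# Undeclared fields are accepted and survive serialization.
info = TokenizerInfoSketch(tokenizer_class="LlamaTokenizerFast",
                           add_bos_token=True)
assert info.model_dump()["add_bos_token"] is True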

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Final, Optional, Union
from dataclasses import dataclass
from typing import Any, Final, Optional, Union

import jinja2
from fastapi import Request
@@ -17,11 +17,13 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
                                              ErrorResponse,
                                              TokenizeChatRequest,
                                              TokenizeRequest,
                                              TokenizeResponse)
                                              TokenizeResponse,
                                              TokenizerInfoResponse)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer

logger = init_logger(__name__)
@@ -155,3 +157,49 @@ class OpenAIServingTokenization(OpenAIServing):
        input_text = prompt_input["prompt"]

        return DetokenizeResponse(prompt=input_text)

    async def get_tokenizer_info(
            self, ) -> Union[TokenizerInfoResponse, ErrorResponse]:
        """Get comprehensive tokenizer information."""
        try:
            tokenizer = await self.engine_client.get_tokenizer()
            info = TokenizerInfo(tokenizer, self.chat_template).to_dict()
            return TokenizerInfoResponse(**info)
        except Exception as e:
            return self.create_error_response(
                f"Failed to get tokenizer info: {str(e)}")


@dataclass
class TokenizerInfo:
    tokenizer: AnyTokenizer
    chat_template: Optional[str]

    def to_dict(self) -> dict[str, Any]:
        """Return the tokenizer configuration."""
        return self._get_tokenizer_config()

    def _get_tokenizer_config(self) -> dict[str, Any]:
        """Get tokenizer configuration directly from the tokenizer object."""
        config = dict(getattr(self.tokenizer, "init_kwargs", None) or {})

        # Remove file path fields
        config.pop("vocab_file", None)
        config.pop("merges_file", None)

        config = self._make_json_serializable(config)
        config["tokenizer_class"] = type(self.tokenizer).__name__
        if self.chat_template:
            config["chat_template"] = self.chat_template
        return config

    def _make_json_serializable(self, obj):
        """Convert any non-JSON-serializable objects to serializable format."""
        if hasattr(obj, "content"):
            return obj.content
        elif isinstance(obj, dict):
            return {k: self._make_json_serializable(v) for k, v in obj.items()}
        elif isinstance(obj, list):
            return [self._make_json_serializable(item) for item in obj]
        else:
            return obj
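
The hasattr(obj, "content") branch exists because a tokenizer's init_kwargs can contain tokenizers.AddedToken objects, which are not JSON-serializable but carry their string form in .content. A small sketch of what the recursion flattens (assuming the tokenizers package is installed):

from tokenizers import AddedToken

token = AddedToken("<|user|>")
assert hasattr(token, "content") and token.content == "<|user|>"

# _make_json_serializable above replaces any such object with its .content
# string while recursing through nested dicts and lists unchanged.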