Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 09:55:46 +08:00)
feat - add a new endpoint get_tokenizer_info to provide tokenizer/chat-template information (#20575)
Signed-off-by: m-misiura <mmisiura@redhat.com>
parent 1c3198b6c4
commit 18bdcf4113
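
Before the diff itself, a brief usage sketch (not part of the commit): once a server is started with the new --enable-tokenizer-info-endpoint flag, the added route can be queried like any other vLLM HTTP endpoint. The model name, base URL, and port below are placeholders.

# Hedged usage sketch -- not part of this commit.
# Assumes a server launched roughly like:
#   vllm serve <model> --enable-tokenizer-info-endpoint
# and reachable on a local port; both are assumptions here.
import requests

BASE_URL = "http://localhost:8000"  # placeholder deployment address

resp = requests.get(f"{BASE_URL}/tokenizer_info")
resp.raise_for_status()
info = resp.json()

# "tokenizer_class" is the one field the new response model requires; the
# remaining keys mirror the served model's tokenizer_config.json.
print(info["tokenizer_class"])
print("chat_template" in info)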
@@ -32,6 +32,7 @@ def server(zephyr_lora_added_tokens_files: str):  # noqa: F811
         f"zephyr-lora2={zephyr_lora_added_tokens_files}",
         "--max-lora-rank",
         "64",
+        "--enable-tokenizer-info-endpoint",
     ]
 
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@@ -283,3 +284,106 @@ async def test_detokenize(
     response.raise_for_status()
 
     assert response.json() == {"prompt": prompt}
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name,tokenizer_name",
+    [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
+    indirect=["tokenizer_name"],
+)
+async def test_tokenizer_info_basic(
+    server: RemoteOpenAIServer,
+    model_name: str,
+    tokenizer_name: str,
+):
+    """Test basic tokenizer info endpoint functionality."""
+    response = requests.get(server.url_for("tokenizer_info"))
+    response.raise_for_status()
+    result = response.json()
+    assert "tokenizer_class" in result
+    assert isinstance(result["tokenizer_class"], str)
+    assert result["tokenizer_class"]
+
+
+@pytest.mark.asyncio
+async def test_tokenizer_info_schema(server: RemoteOpenAIServer):
+    """Test that the response matches expected schema types."""
+    response = requests.get(server.url_for("tokenizer_info"))
+    response.raise_for_status()
+    result = response.json()
+    field_types = {
+        "add_bos_token": bool,
+        "add_prefix_space": bool,
+        "clean_up_tokenization_spaces": bool,
+        "split_special_tokens": bool,
+        "bos_token": str,
+        "eos_token": str,
+        "pad_token": str,
+        "unk_token": str,
+        "chat_template": str,
+        "errors": str,
+        "model_max_length": int,
+        "additional_special_tokens": list,
+        "added_tokens_decoder": dict,
+    }
+    for field, expected_type in field_types.items():
+        if field in result and result[field] is not None:
+            assert isinstance(
+                result[field],
+                expected_type), (f"{field} should be {expected_type.__name__}")
+
+
+@pytest.mark.asyncio
+async def test_tokenizer_info_added_tokens_structure(
+        server: RemoteOpenAIServer, ):
+    """Test added_tokens_decoder structure if present."""
+    response = requests.get(server.url_for("tokenizer_info"))
+    response.raise_for_status()
+    result = response.json()
+    added_tokens = result.get("added_tokens_decoder")
+    if added_tokens:
+        for token_id, token_info in added_tokens.items():
+            assert isinstance(token_id, str), "Token IDs should be strings"
+            assert isinstance(token_info, dict), "Token info should be a dict"
+            assert "content" in token_info, "Token info should have content"
+            assert "special" in token_info, (
+                "Token info should have special flag")
+            assert isinstance(token_info["special"],
+                              bool), ("Special flag should be boolean")
+
+
+@pytest.mark.asyncio
+async def test_tokenizer_info_consistency_with_tokenize(
+        server: RemoteOpenAIServer, ):
+    """Test that tokenizer info is consistent with tokenization endpoint."""
+    info_response = requests.get(server.url_for("tokenizer_info"))
+    info_response.raise_for_status()
+    info = info_response.json()
+    tokenize_response = requests.post(
+        server.url_for("tokenize"),
+        json={
+            "model": MODEL_NAME,
+            "prompt": "Hello world!"
+        },
+    )
+    tokenize_response.raise_for_status()
+    tokenize_result = tokenize_response.json()
+    info_max_len = info.get("model_max_length")
+    tokenize_max_len = tokenize_result.get("max_model_len")
+    if info_max_len and tokenize_max_len:
+        assert info_max_len >= tokenize_max_len, (
+            "Info max length should be >= tokenize max length")
+
+
+@pytest.mark.asyncio
+async def test_tokenizer_info_chat_template(server: RemoteOpenAIServer):
+    """Test chat template is properly included."""
+    response = requests.get(server.url_for("tokenizer_info"))
+    response.raise_for_status()
+    result = response.json()
+    chat_template = result.get("chat_template")
+    if chat_template:
+        assert isinstance(chat_template,
+                          str), ("Chat template should be a string")
+        assert chat_template.strip(), "Chat template should not be empty"
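
To make the schema test above concrete, here is a purely illustrative sketch of a payload shape it would accept. Every value below is invented; a real response depends entirely on the served model's tokenizer and may contain additional keys.

example_tokenizer_info = {  # illustrative values only
    "tokenizer_class": "LlamaTokenizerFast",  # the one required string field
    "add_bos_token": True,
    "clean_up_tokenization_spaces": False,
    "bos_token": "<s>",
    "eos_token": "</s>",
    "model_max_length": 4096,
    "chat_template": "{% for message in messages %}...{% endfor %}",
    "additional_special_tokens": [],
    "added_tokens_decoder": {
        "0": {"content": "<unk>", "special": True},
    },
}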
@@ -522,6 +522,19 @@ async def detokenize(request: DetokenizeRequest, raw_request: Request):
     assert_never(generator)
 
 
+def maybe_register_tokenizer_info_endpoint(args):
+    """Conditionally register the tokenizer info endpoint if enabled."""
+    if getattr(args, 'enable_tokenizer_info_endpoint', False):
+
+        @router.get("/tokenizer_info")
+        async def get_tokenizer_info(raw_request: Request):
+            """Get comprehensive tokenizer information."""
+            result = await tokenization(raw_request).get_tokenizer_info()
+            return JSONResponse(content=result.model_dump(),
+                                status_code=result.code if isinstance(
+                                    result, ErrorResponse) else 200)
+
+
 @router.get("/v1/models")
 async def show_available_models(raw_request: Request):
     handler = models(raw_request)
@@ -1692,6 +1705,7 @@ async def run_server_worker(listen_address,
         uvicorn_kwargs['log_config'] = log_config
 
     async with build_async_engine_client(args, client_config) as engine_client:
+        maybe_register_tokenizer_info_endpoint(args)
         app = build_app(args)
 
         vllm_config = await engine_client.get_vllm_config()
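
A note on the ordering above: maybe_register_tokenizer_info_endpoint(args) is called before build_app(args), presumably because FastAPI copies a router's routes into the app at include time, so a route decorated onto the module-level router afterwards would never be mounted. A minimal standalone sketch of that pattern (names and payload are illustrative, not vLLM code):

# Minimal FastAPI sketch of conditional route registration; illustrative only.
from fastapi import APIRouter, FastAPI

router = APIRouter()


def maybe_register(enabled: bool) -> None:
    if enabled:

        @router.get("/tokenizer_info")
        async def tokenizer_info():
            return {"tokenizer_class": "PlaceholderTokenizer"}  # placeholder


def build_app() -> FastAPI:
    app = FastAPI()
    app.include_router(router)  # only routes registered so far are mounted
    return app


maybe_register(enabled=True)  # register first...
app = build_app()             # ...then build the app that includes the router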
@@ -182,6 +182,9 @@ schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
     """If set to True, enable tracking server_load_metrics in the app state."""
     enable_force_include_usage: bool = False
     """If set to True, including usage on every request."""
+    enable_tokenizer_info_endpoint: bool = False
+    """Enable the /get_tokenizer_info endpoint. May expose chat
+    templates and other tokenizer configuration."""
 
     @staticmethod
     def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
@@ -1953,6 +1953,16 @@ class DetokenizeResponse(OpenAIBaseModel):
     prompt: str
 
 
+class TokenizerInfoResponse(OpenAIBaseModel):
+    """
+    Response containing tokenizer configuration
+    equivalent to tokenizer_config.json
+    """
+
+    model_config = ConfigDict(extra="allow")
+    tokenizer_class: str
+
+
 class LoadLoRAAdapterRequest(BaseModel):
     lora_name: str
     lora_path: str
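
The extra="allow" configuration is what lets TokenizerInfoResponse carry the full, model-dependent contents of tokenizer_config.json while only pinning down tokenizer_class. A small standalone pydantic v2 sketch of that behavior (the class name and values are illustrative, not the vLLM model itself):

# Standalone pydantic v2 sketch of extra="allow"; names and values illustrative.
from pydantic import BaseModel, ConfigDict


class TokenizerInfoSketch(BaseModel):
    model_config = ConfigDict(extra="allow")
    tokenizer_class: str


resp = TokenizerInfoSketch(
    tokenizer_class="LlamaTokenizerFast",  # the single declared field
    bos_token="<s>",                       # extra keys are accepted...
    model_max_length=4096,
)
print(resp.model_dump())  # ...and preserved: the extras appear in the dump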
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-from typing import Final, Optional, Union
+from dataclasses import dataclass
+from typing import Any, Final, Optional, Union
 
 import jinja2
 from fastapi import Request
@@ -17,11 +17,13 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
                                               ErrorResponse,
                                               TokenizeChatRequest,
                                               TokenizeRequest,
-                                              TokenizeResponse)
+                                              TokenizeResponse,
+                                              TokenizerInfoResponse)
 # yapf: enable
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.logger import init_logger
+from vllm.transformers_utils.tokenizer import AnyTokenizer
 
 logger = init_logger(__name__)
 
@@ -155,3 +157,49 @@ class OpenAIServingTokenization(OpenAIServing):
         input_text = prompt_input["prompt"]
 
         return DetokenizeResponse(prompt=input_text)
+
+    async def get_tokenizer_info(
+            self, ) -> Union[TokenizerInfoResponse, ErrorResponse]:
+        """Get comprehensive tokenizer information."""
+        try:
+            tokenizer = await self.engine_client.get_tokenizer()
+            info = TokenizerInfo(tokenizer, self.chat_template).to_dict()
+            return TokenizerInfoResponse(**info)
+        except Exception as e:
+            return self.create_error_response(
+                f"Failed to get tokenizer info: {str(e)}")
+
+
+@dataclass
+class TokenizerInfo:
+    tokenizer: AnyTokenizer
+    chat_template: Optional[str]
+
+    def to_dict(self) -> dict[str, Any]:
+        """Return the tokenizer configuration."""
+        return self._get_tokenizer_config()
+
+    def _get_tokenizer_config(self) -> dict[str, Any]:
+        """Get tokenizer configuration directly from the tokenizer object."""
+        config = dict(getattr(self.tokenizer, "init_kwargs", None) or {})
+
+        # Remove file path fields
+        config.pop("vocab_file", None)
+        config.pop("merges_file", None)
+
+        config = self._make_json_serializable(config)
+        config["tokenizer_class"] = type(self.tokenizer).__name__
+        if self.chat_template:
+            config["chat_template"] = self.chat_template
+        return config
+
+    def _make_json_serializable(self, obj):
+        """Convert any non-JSON-serializable objects to serializable format."""
+        if hasattr(obj, "content"):
+            return obj.content
+        elif isinstance(obj, dict):
+            return {k: self._make_json_serializable(v) for k, v in obj.items()}
+        elif isinstance(obj, list):
+            return [self._make_json_serializable(item) for item in obj]
+        else:
+            return obj
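
The _make_json_serializable helper above collapses objects exposing a .content attribute (such as Hugging Face AddedToken entries in the tokenizer's init_kwargs) to their string content, and walks dicts and lists recursively. A tiny standalone illustration (FakeAddedToken is a hypothetical stand-in, used only to mimic that attribute):

# Standalone illustration of the flattening rule; FakeAddedToken is hypothetical.
from typing import Any


class FakeAddedToken:

    def __init__(self, content: str):
        self.content = content


def make_json_serializable(obj: Any) -> Any:
    if hasattr(obj, "content"):
        return obj.content
    if isinstance(obj, dict):
        return {k: make_json_serializable(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [make_json_serializable(v) for v in obj]
    return obj


config = {
    "bos_token": FakeAddedToken("<s>"),
    "additional_special_tokens": [FakeAddedToken("<extra_0>")],
}
print(make_json_serializable(config))
# -> {'bos_token': '<s>', 'additional_special_tokens': ['<extra_0>']}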
|||||||
Loading…
x
Reference in New Issue
Block a user