[BugFix][Frontend] Use LoRA tokenizer in OpenAI APIs (#6227)

Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Authored by Nick Hill on 2024-07-18 00:13:30 -07:00 (committed by GitHub)
parent 8a74c68bd1
commit e2fbaee725
16 changed files with 267 additions and 186 deletions

View File

@@ -1,6 +1,5 @@
 import os
 import pathlib
-from dataclasses import dataclass
 import pytest
@@ -50,23 +49,9 @@ TEST_MESSAGES = [
 ]
-@dataclass
-class MockTokenizer:
-    chat_template = None
-@dataclass
-class MockServingChat:
-    tokenizer: MockTokenizer
 def test_load_chat_template():
     # Testing chatml template
-    tokenizer = MockTokenizer()
-    mock_serving_chat = MockServingChat(tokenizer)
-    load_chat_template(mock_serving_chat, chat_template=chatml_jinja_path)
-    template_content = tokenizer.chat_template
+    template_content = load_chat_template(chat_template=chatml_jinja_path)
     # Test assertions
     assert template_content is not None
@@ -78,22 +63,16 @@ def test_load_chat_template():
 def test_no_load_chat_template_filelike():
     # Testing chatml template
     template = "../../examples/does_not_exist"
-    tokenizer = MockTokenizer()
-    mock_serving_chat = MockServingChat(tokenizer)
     with pytest.raises(ValueError, match="looks like a file path"):
-        load_chat_template(mock_serving_chat, chat_template=template)
+        load_chat_template(chat_template=template)
 def test_no_load_chat_template_literallike():
     # Testing chatml template
     template = "{{ messages }}"
-    tokenizer = MockTokenizer()
-    mock_serving_chat = MockServingChat(tokenizer)
-    load_chat_template(mock_serving_chat, chat_template=template)
-    template_content = tokenizer.chat_template
+    template_content = load_chat_template(chat_template=template)
     assert template_content == template
@@ -105,8 +84,7 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
                         expected_output):
     # Initialize the tokenizer
     tokenizer = get_tokenizer(tokenizer_name=model)
-    mock_serving_chat = MockServingChat(tokenizer)
-    load_chat_template(mock_serving_chat, chat_template=template)
+    template_content = load_chat_template(chat_template=template)
     # Create a mock request object using keyword arguments
     mock_request = ChatCompletionRequest(
@@ -118,7 +96,8 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
     result = tokenizer.apply_chat_template(
         conversation=mock_request.messages,
         tokenize=False,
-        add_generation_prompt=mock_request.add_generation_prompt)
+        add_generation_prompt=mock_request.add_generation_prompt,
+        chat_template=mock_request.chat_template or template_content)
     # Test assertion
     assert result == expected_output, (
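For reference, a minimal sketch (not part of this commit) of the call pattern these updated tests exercise: load_chat_template now returns the resolved template string instead of mutating a tokenizer, and callers pass that string to apply_chat_template explicitly. The import path of load_chat_template is an assumption here, taken from the chat-utils module rewritten later in this diff.

# Hedged sketch; the load_chat_template module path is an assumption.
from vllm.entrypoints.openai.chat_utils import load_chat_template
from vllm.transformers_utils.tokenizer import get_tokenizer

template_content = load_chat_template(chat_template="{{ messages }}")
tokenizer = get_tokenizer(tokenizer_name="HuggingFaceH4/zephyr-7b-beta")
prompt = tokenizer.apply_chat_template(
    conversation=[{"role": "user", "content": "Hello!"}],
    tokenize=False,
    add_generation_prompt=True,
    # the preloaded template is passed per call instead of being stored on the tokenizer
    chat_template=template_content,
)
print(prompt)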

View File

@@ -7,11 +7,11 @@ import jsonschema
 import openai  # use the official client for correctness check
 import pytest
 import torch
-# downloading lora to test lora requests
-from huggingface_hub import snapshot_download
 from openai import BadRequestError
 from ...utils import RemoteOpenAIServer
+from .test_completion import zephyr_lora_added_tokens_files  # noqa: F401
+from .test_completion import zephyr_lora_files  # noqa: F401
 # any model with a chat template should work here
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
@@ -21,12 +21,7 @@ LORA_NAME = "typeof/zephyr-7b-beta-lora"
 @pytest.fixture(scope="module")
-def zephyr_lora_files():
-    return snapshot_download(repo_id=LORA_NAME)
-@pytest.fixture(scope="module")
-def server(zephyr_lora_files):
+def server(zephyr_lora_files, zephyr_lora_added_tokens_files):  # noqa: F811
     args = [
         # use half precision for speed and memory savings in CI environment
         "--dtype",
@@ -38,7 +33,7 @@ def server(zephyr_lora_files):
         "--enable-lora",
         "--lora-modules",
         f"zephyr-lora={zephyr_lora_files}",
-        f"zephyr-lora2={zephyr_lora_files}",
+        f"zephyr-lora2={zephyr_lora_added_tokens_files}",
         "--max-lora-rank",
         "64",
         "--max-cpu-loras",

View File

@@ -1,6 +1,8 @@
 # imports for guided decoding tests
 import json
 import re
+import shutil
+from tempfile import TemporaryDirectory
 from typing import List
 import jsonschema
@@ -9,6 +11,7 @@ import pytest
 # downloading lora to test lora requests
 from huggingface_hub import snapshot_download
 from openai import BadRequestError
+from transformers import AutoTokenizer
 from vllm.transformers_utils.tokenizer import get_tokenizer
@@ -30,13 +33,29 @@ def zephyr_lora_files():
     return snapshot_download(repo_id=LORA_NAME)
+@pytest.fixture(scope="module")
+def zephyr_lora_added_tokens_files(zephyr_lora_files):
+    tmp_dir = TemporaryDirectory()
+    tmp_model_dir = f"{tmp_dir.name}/zephyr"
+    shutil.copytree(zephyr_lora_files, tmp_model_dir)
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    # Copy tokenizer to adapter and add some unique tokens
+    # 32000, 32001, 32002
+    added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"],
+                                 special_tokens=True)
+    assert added == 3
+    tokenizer.save_pretrained(tmp_model_dir)
+    yield tmp_model_dir
+    tmp_dir.cleanup()
 @pytest.fixture(scope="module")
 def zephyr_pa_files():
     return snapshot_download(repo_id=PA_NAME)
 @pytest.fixture(scope="module")
-def server(zephyr_lora_files, zephyr_pa_files):
+def server(zephyr_lora_files, zephyr_lora_added_tokens_files, zephyr_pa_files):
     args = [
         # use half precision for speed and memory savings in CI environment
         "--dtype",
@@ -50,7 +69,7 @@ def server(zephyr_lora_files, zephyr_pa_files):
         "--enable-lora",
         "--lora-modules",
         f"zephyr-lora={zephyr_lora_files}",
-        f"zephyr-lora2={zephyr_lora_files}",
+        f"zephyr-lora2={zephyr_lora_added_tokens_files}",
         "--max-lora-rank",
         "64",
         "--max-cpu-loras",
@@ -111,6 +130,34 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
     assert len(completion.choices[0].text) >= 1
+@pytest.mark.asyncio
+async def test_added_lora_tokens(client: openai.AsyncOpenAI):
+    # test using token IDs
+    completion = await client.completions.create(
+        model="zephyr-lora2",
+        prompt=[0, 0, 32000, 32001, 32002],
+        echo=True,
+        max_tokens=5,
+        temperature=0.0,
+    )
+    # Added tokens should appear in tokenized prompt
+    assert completion.choices[0].text.startswith("<unk><unk>vllm1vllm2vllm3")
+@pytest.mark.asyncio
+async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI):
+    # test using token IDs
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=[0, 0, 32000, 32001, 32002],
+        echo=True,
+        max_tokens=5,
+        temperature=0.0,
+    )
+    # Added tokens should not appear in tokenized prompt
+    assert "vllm" not in completion.choices[0].text
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     # first test base model, then test loras, then test prompt adapters
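As an illustration (not part of the commit) of what the fixture above builds: the adapter directory ships a tokenizer whose three added tokens land at IDs 32000-32002, while the base tokenizer's vocabulary ends at 32000; the two new tests assert exactly this through the server. A hedged local mirror, assuming Hub access to the same base model:

# Hedged sketch: mirrors the fixture above locally; requires network access.
from transformers import AutoTokenizer

base = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
lora = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
lora.add_tokens(["vllm1", "vllm2", "vllm3"], special_tokens=True)

print(lora.convert_tokens_to_ids(["vllm1", "vllm2", "vllm3"]))  # [32000, 32001, 32002]
print(len(base))  # 32000 -> those IDs are out-of-vocab for the base tokenizer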

View File

@@ -38,5 +38,4 @@ async def _async_serving_chat_init():
 def test_async_serving_chat_init():
     serving_completion = asyncio.run(_async_serving_chat_init())
-    assert serving_completion.tokenizer is not None
-    assert serving_completion.tokenizer.chat_template == CHAT_TEMPLATE
+    assert serving_completion.chat_template == CHAT_TEMPLATE

View File

@@ -5,13 +5,15 @@ import requests
 from vllm.transformers_utils.tokenizer import get_tokenizer
 from ...utils import RemoteOpenAIServer
+from .test_completion import zephyr_lora_added_tokens_files  # noqa: F401
+from .test_completion import zephyr_lora_files  # noqa: F401
 # any model with a chat template should work here
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
 @pytest.fixture(scope="module")
-def server():
+def server(zephyr_lora_added_tokens_files: str):  # noqa: F811
     args = [
         # use half precision for speed and memory savings in CI environment
         "--dtype",
@@ -21,12 +23,25 @@ def server():
         "--enforce-eager",
         "--max-num-seqs",
         "128",
+        # lora config
+        "--enable-lora",
+        "--lora-modules",
+        f"zephyr-lora2={zephyr_lora_added_tokens_files}",
+        "--max-lora-rank",
+        "64",
     ]
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
         yield remote_server
+@pytest.fixture(scope="module")
+def tokenizer_name(model_name: str,
+                   zephyr_lora_added_tokens_files: str):  # noqa: F811
+    return zephyr_lora_added_tokens_files if (
+        model_name == "zephyr-lora2") else model_name
 @pytest.fixture(scope="module")
 def client(server):
     return server.get_async_client()
@@ -34,16 +49,18 @@ def client(server):
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
-    "model_name",
-    [MODEL_NAME],
+    "model_name,tokenizer_name",
+    [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
+    indirect=["tokenizer_name"],
 )
 async def test_tokenize_completions(client: openai.AsyncOpenAI,
-                                    model_name: str):
+                                    model_name: str, tokenizer_name: str):
     base_url = str(client.base_url)[:-3].strip("/")
-    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
+    tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
+                              tokenizer_mode="fast")
     for add_special in [False, True]:
-        prompt = "This is a test prompt."
+        prompt = "vllm1 This is a test prompt."
         tokens = tokenizer.encode(prompt, add_special_tokens=add_special)
         response = requests.post(base_url + "/tokenize",
@@ -63,12 +80,15 @@ async def test_tokenize_completions(client: openai.AsyncOpenAI,
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
-    "model_name",
-    [MODEL_NAME],
+    "model_name,tokenizer_name",
+    [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
+    indirect=["tokenizer_name"],
 )
-async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str):
+async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str,
+                             tokenizer_name: str):
     base_url = str(client.base_url)[:-3].strip("/")
-    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
+    tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
+                              tokenizer_mode="fast")
     for add_generation in [False, True]:
         for add_special in [False, True]:
@@ -80,7 +100,7 @@ async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str):
                 "content": "Nice to meet you!"
             }, {
                 "role": "user",
-                "content": "Can I ask a question?"
+                "content": "Can I ask a question? vllm1"
             }]
             prompt = tokenizer.apply_chat_template(
@@ -108,16 +128,20 @@ async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str):
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
-    "model_name",
-    [MODEL_NAME],
+    "model_name,tokenizer_name",
+    [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
+    indirect=["tokenizer_name"],
 )
-async def test_detokenize(client: openai.AsyncOpenAI, model_name: str):
+async def test_detokenize(client: openai.AsyncOpenAI, model_name: str,
+                          tokenizer_name: str):
     base_url = str(client.base_url)[:-3].strip("/")
-    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
+    tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
+                              tokenizer_mode="fast")
-    prompt = "This is a test prompt."
+    prompt = "This is a test prompt. vllm1"
     tokens = tokenizer.encode(prompt, add_special_tokens=False)
+    print(f"CALLING {base_url} FOR {model_name}")
     response = requests.post(base_url + "/detokenize",
                              json={
                                  "model": model_name,

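A hedged end-to-end sketch of what these tests check once the server is running: POSTing to /tokenize with the LoRA model name makes the server tokenize with that adapter's tokenizer, so the added "vllm1" token maps to a single ID. The host/port and the exact shape of the response body are assumptions for illustration only.

import requests

resp = requests.post(
    "http://localhost:8000/tokenize",  # assumed server address
    json={
        "model": "zephyr-lora2",  # adapter registered via --lora-modules
        "prompt": "vllm1 This is a test prompt.",
        "add_special_tokens": True,
    },
)
resp.raise_for_status()
print(resp.json())  # expected to include the token IDs produced by the LoRA tokenizer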
View File

@@ -480,11 +480,16 @@ class AsyncLLMEngine:
             self.set_errored(exc)
             self._request_tracker.propagate_exception(exc)
-    async def get_tokenizer(self) -> "PreTrainedTokenizer":
+    async def get_tokenizer(
+        self,
+        lora_request: Optional[LoRARequest] = None,
+    ) -> "PreTrainedTokenizer":
         if self.engine_use_ray:
-            return await self.engine.get_tokenizer.remote()  # type: ignore
-        else:
-            return self.engine.get_tokenizer()
+            return await self.engine.get_tokenizer.remote(  # type: ignore
+                lora_request)
+        return await (self.engine.get_tokenizer_group().
+                      get_lora_tokenizer_async(lora_request))
     def start_background_loop(self) -> None:
         """Start the background loop."""

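A minimal sketch (engine construction and the adapter path are placeholders) of the new lookup this hunk enables: AsyncLLMEngine.get_tokenizer now accepts an optional LoRARequest and resolves the tokenizer through the engine's tokenizer group, so adapters with added tokens get their own tokenizer.

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.lora.request import LoRARequest

async def show_tokenizers() -> None:
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="HuggingFaceH4/zephyr-7b-beta", enable_lora=True))
    lora = LoRARequest("zephyr-lora2", 1, "/path/to/adapter")  # hypothetical path
    base_tok = await engine.get_tokenizer()      # base-model tokenizer
    lora_tok = await engine.get_tokenizer(lora)  # adapter tokenizer (added tokens)
    print(len(base_tok), len(lora_tok))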
View File

@@ -455,8 +455,11 @@ class LLMEngine:
         return self.tokenizer
-    def get_tokenizer(self) -> "PreTrainedTokenizer":
-        return self.get_tokenizer_group().get_lora_tokenizer(None)
+    def get_tokenizer(
+        self,
+        lora_request: Optional[LoRARequest] = None
+    ) -> "PreTrainedTokenizer":
+        return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
     def get_tokenizer_for_seq(self,
                               sequence: Sequence) -> "PreTrainedTokenizer":

View File

@@ -257,7 +257,8 @@ def run_server(args, llm_engine=None):
     openai_serving_embedding = OpenAIServingEmbedding(engine, model_config,
                                                       served_model_names)
     openai_serving_tokenization = OpenAIServingTokenization(
-        engine, model_config, served_model_names, args.chat_template)
+        engine, model_config, served_model_names, args.lora_modules,
+        args.chat_template)
     app.root_path = args.root_path
     logger.info("Available routes are:")

View File

@@ -5,10 +5,11 @@ from typing import Awaitable, Iterable, List, Optional, TypedDict, cast, final
 from openai.types.chat import (ChatCompletionContentPartImageParam,
                                ChatCompletionContentPartTextParam)
+from transformers import PreTrainedTokenizer
+from vllm.config import ModelConfig
 from vllm.entrypoints.openai.protocol import (ChatCompletionContentPartParam,
                                               ChatCompletionMessageParam)
-from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalDataDict
 from vllm.multimodal.utils import async_get_and_parse_image
@@ -29,40 +30,34 @@ class ChatMessageParseResult:
                                                       default_factory=list)
-def load_chat_template(engine: OpenAIServing, chat_template: Optional[str]):
-    tokenizer = engine.tokenizer
-    if chat_template is not None:
-        try:
-            with open(chat_template, "r") as f:
-                tokenizer.chat_template = f.read()
-        except OSError as e:
-            JINJA_CHARS = "{}\n"
-            if not any(c in chat_template for c in JINJA_CHARS):
-                msg = (f"The supplied chat template ({chat_template}) "
-                       f"looks like a file path, but it failed to be "
-                       f"opened. Reason: {e}")
-                raise ValueError(msg) from e
-            # If opening a file fails, set chat template to be args to
-            # ensure we decode so our escape are interpreted correctly
-            tokenizer.chat_template = codecs.decode(chat_template,
-                                                    "unicode_escape")
-        logger.info("Using supplied chat template:\n%s",
-                    tokenizer.chat_template)
-    elif tokenizer.chat_template is not None:
-        logger.info("Using default chat template:\n%s",
-                    tokenizer.chat_template)
-    else:
-        logger.warning("No chat template provided. Chat API will not work.")
+def load_chat_template(chat_template: Optional[str]) -> Optional[str]:
+    if chat_template is None:
+        return None
+    try:
+        with open(chat_template, "r") as f:
+            resolved_chat_template = f.read()
+    except OSError as e:
+        JINJA_CHARS = "{}\n"
+        if not any(c in chat_template for c in JINJA_CHARS):
+            msg = (f"The supplied chat template ({chat_template}) "
+                   f"looks like a file path, but it failed to be "
+                   f"opened. Reason: {e}")
+            raise ValueError(msg) from e
+        # If opening a file fails, set chat template to be args to
+        # ensure we decode so our escape are interpreted correctly
+        resolved_chat_template = codecs.decode(chat_template, "unicode_escape")
+    logger.info("Using supplied chat template:\n%s", resolved_chat_template)
+    return resolved_chat_template
 @lru_cache(maxsize=None)
-def _image_token_str(engine: OpenAIServing) -> Optional[str]:
+def _image_token_str(model_config: ModelConfig,
+                     tokenizer: PreTrainedTokenizer) -> Optional[str]:
     # TODO: Let user specify how to insert image tokens into prompt
     # (similar to chat template)
-    model_type = engine.model_config.hf_config.model_type
+    model_type = model_config.hf_config.model_type
     if model_type == "phi3_v":
         # Workaround since this token is not defined in the tokenizer
         return "<|image_1|>"
@@ -70,17 +65,14 @@ def _image_token_str(model_config, tokenizer):
         # These models do not use image tokens in the prompt
         return None
     if model_type.startswith("llava"):
-        return engine.tokenizer.decode(
-            engine.model_config.hf_config.image_token_index)
-    else:
-        raise TypeError("Unknown model type: {model_type}")
+        return tokenizer.decode(model_config.hf_config.image_token_index)
+    raise TypeError("Unknown model type: {model_type}")
 # TODO: Let user specify how to insert image tokens into prompt
 # (similar to chat template)
-def _get_full_image_text_prompt(engine: OpenAIServing, image_token_str: str,
-                                text_prompt: str) -> str:
+def _get_full_image_text_prompt(image_token_str: str, text_prompt: str) -> str:
     """Combine image and text prompts for vision language model"""
     # NOTE: For now we assume all model architectures use the same
@@ -89,9 +81,10 @@ def _get_full_image_text_prompt(image_token_str, text_prompt):
 def _parse_chat_message_content_parts(
-    engine: OpenAIServing,
     role: str,
     parts: Iterable[ChatCompletionContentPartParam],
+    model_config: ModelConfig,
+    tokenizer: PreTrainedTokenizer,
 ) -> ChatMessageParseResult:
     texts: List[str] = []
     mm_futures: List[Awaitable[MultiModalDataDict]] = []
@@ -122,7 +115,7 @@ def _parse_chat_message_content_parts(
     text_prompt = "\n".join(texts)
     if mm_futures:
-        image_token_str = _image_token_str(engine)
+        image_token_str = _image_token_str(model_config, tokenizer)
         if image_token_str is not None:
             if image_token_str in text_prompt:
                 logger.warning(
@@ -130,7 +123,6 @@ def _parse_chat_message_content_parts(
                     "Skipping prompt formatting.")
             else:
                 text_prompt = _get_full_image_text_prompt(
-                    engine,
                     image_token_str=image_token_str,
                     text_prompt=text_prompt,
                 )
@@ -141,8 +133,9 @@ def _parse_chat_message_content_parts(
 def parse_chat_message_content(
-    engine: OpenAIServing,
     message: ChatCompletionMessageParam,
+    model_config: ModelConfig,
+    tokenizer: PreTrainedTokenizer,
 ) -> ChatMessageParseResult:
     role = message["role"]
     content = message.get("content")
@@ -153,4 +146,5 @@ def parse_chat_message_content(
         messages = [ConversationMessage(role=role, content=content)]
         return ChatMessageParseResult(messages=messages, mm_futures=[])
-    return _parse_chat_message_content_parts(engine, role, content)
+    return _parse_chat_message_content_parts(role, content, model_config,
+                                             tokenizer)
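To make the control flow of the rewritten load_chat_template above concrete, here is a hedged, self-contained mirror of its logic (a local copy for illustration, not the vLLM import, and minus the logging): a readable file is returned verbatim, a string without Jinja characters that fails to open is rejected as a bad path, and anything else is unicode-unescaped so templates passed on the command line with literal "\n" sequences are interpreted correctly.

import codecs

def resolve_template(chat_template):
    # local mirror of the diff's logic, for illustration only
    if chat_template is None:
        return None
    try:
        with open(chat_template, "r") as f:
            return f.read()
    except OSError as e:
        if not any(c in chat_template for c in "{}\n"):
            raise ValueError(
                f"The supplied chat template ({chat_template}) looks like a "
                f"file path, but it failed to be opened. Reason: {e}") from e
        return codecs.decode(chat_template, "unicode_escape")

assert resolve_template("{{ messages }}") == "{{ messages }}"
assert resolve_template("{{ bos }}\\n{{ eos }}") == "{{ bos }}\n{{ eos }}"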

View File

@@ -5,6 +5,7 @@ from typing import Sequence as GenericSequence
 from typing import Union
 from fastapi import Request
+from transformers import PreTrainedTokenizer
 from vllm.config import ModelConfig
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -49,7 +50,9 @@ class OpenAIServingChat(OpenAIServing):
                          lora_modules=lora_modules)
         self.response_role = response_role
-        load_chat_template(self, chat_template)
+        # If this is None we use the tokenizer's default chat template
+        self.chat_template = load_chat_template(chat_template)
     async def create_chat_completion(
         self,
@@ -71,11 +74,15 @@ class OpenAIServingChat(OpenAIServing):
             return error_check_ret
         try:
+            _, lora_request = self._maybe_get_adapter(request)
+            tokenizer = await self.engine.get_tokenizer(lora_request)
             conversation: List[ConversationMessage] = []
             mm_futures: List[Awaitable[MultiModalDataDict]] = []
             for msg in request.messages:
-                chat_parsed_result = parse_chat_message_content(self, msg)
+                chat_parsed_result = parse_chat_message_content(
+                    msg, self.model_config, tokenizer)
                 conversation.extend(chat_parsed_result.messages)
                 mm_futures.extend(chat_parsed_result.mm_futures)
@@ -84,13 +91,13 @@ class OpenAIServingChat(OpenAIServing):
                     tool.model_dump() for tool in request.tools
                 ]
-            prompt = self.tokenizer.apply_chat_template(
+            prompt = tokenizer.apply_chat_template(
                 conversation=conversation,
                 tokenize=False,
                 add_generation_prompt=request.add_generation_prompt,
                 tools=tool_dicts,
                 documents=request.documents,
-                chat_template=request.chat_template,
+                chat_template=request.chat_template or self.chat_template,
                 **(request.chat_template_kwargs or {}),
             )
         except Exception as e:
@@ -112,19 +119,19 @@ class OpenAIServingChat(OpenAIServing):
         request_id = f"cmpl-{random_uuid()}"
         try:
             # Tokenize/detokenize depending on prompt format (string/token list)
-            prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
+            prompt_ids, prompt_text = await self._validate_prompt_and_tokenize(
                 request,
+                tokenizer,
                 prompt=prompt,
                 add_special_tokens=request.add_special_tokens)
             sampling_params = request.to_sampling_params()
-            _, lora_request = self._maybe_get_adapter(request)
             decoding_config = await self.engine.get_decoding_config()
             guided_decoding_backend = request.guided_decoding_backend \
                 or decoding_config.guided_decoding_backend
             guided_decode_logits_processor = (
-                await get_guided_decoding_logits_processor(
-                    guided_decoding_backend, request,
-                    self.engine.get_tokenizer()))
+                await
+                get_guided_decoding_logits_processor(guided_decoding_backend,
+                                                     request, tokenizer))
             if guided_decode_logits_processor:
                 if sampling_params.logits_processors is None:
                     sampling_params.logits_processors = []
@@ -158,12 +165,12 @@ class OpenAIServingChat(OpenAIServing):
         # Streaming response
         if request.stream:
             return self.chat_completion_stream_generator(
-                request, result_generator, request_id, conversation)
+                request, result_generator, request_id, conversation, tokenizer)
         else:
             try:
                 return await self.chat_completion_full_generator(
                     request, raw_request, result_generator, request_id,
-                    conversation)
+                    conversation, tokenizer)
             except ValueError as e:
                 # TODO: Use a vllm-specific Validation Error
                 return self.create_error_response(str(e))
@@ -175,9 +182,12 @@ class OpenAIServingChat(OpenAIServing):
             return request.messages[-1]["role"]
     async def chat_completion_stream_generator(
-            self, request: ChatCompletionRequest,
-            result_generator: AsyncIterator[RequestOutput], request_id: str,
-            conversation: List[ConversationMessage]
+        self,
+        request: ChatCompletionRequest,
+        result_generator: AsyncIterator[RequestOutput],
+        request_id: str,
+        conversation: List[ConversationMessage],
+        tokenizer: PreTrainedTokenizer,
     ) -> AsyncGenerator[str, None]:
         model_name = self.served_model_names[0]
         created_time = int(time.time())
@@ -264,6 +274,7 @@ class OpenAIServingChat(OpenAIServing):
                     logprobs = self._create_chat_logprobs(
                         token_ids=delta_token_ids,
                         top_logprobs=out_logprobs,
+                        tokenizer=tokenizer,
                         num_output_top_logprobs=request.top_logprobs,
                     )
                 else:
@@ -352,9 +363,13 @@ class OpenAIServingChat(OpenAIServing):
         yield "data: [DONE]\n\n"
     async def chat_completion_full_generator(
-            self, request: ChatCompletionRequest, raw_request: Optional[Request],
-            result_generator: AsyncIterator[RequestOutput], request_id: str,
-            conversation: List[ConversationMessage]
+        self,
+        request: ChatCompletionRequest,
+        raw_request: Optional[Request],
+        result_generator: AsyncIterator[RequestOutput],
+        request_id: str,
+        conversation: List[ConversationMessage],
+        tokenizer: PreTrainedTokenizer,
     ) -> Union[ErrorResponse, ChatCompletionResponse]:
         model_name = self.served_model_names[0]
@@ -382,6 +397,7 @@ class OpenAIServingChat(OpenAIServing):
                     token_ids=token_ids,
                     top_logprobs=out_logprobs,
                     num_output_top_logprobs=request.top_logprobs,
+                    tokenizer=tokenizer,
                 )
             else:
                 logprobs = None
@@ -436,16 +452,14 @@ class OpenAIServingChat(OpenAIServing):
         return response
     def _get_top_logprobs(
-            self, logprobs: Dict[int, Logprob],
-            top_logprobs: Optional[int]) -> List[ChatCompletionLogProb]:
+            self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
+            tokenizer: PreTrainedTokenizer) -> List[ChatCompletionLogProb]:
         return [
             ChatCompletionLogProb(
-                token=self._get_decoded_token(p[1], p[0]),
+                token=(token := self._get_decoded_token(p[1], p[0],
+                                                        tokenizer)),
                 logprob=max(p[1].logprob, -9999.0),
-                bytes=list(
-                    self._get_decoded_token(p[1],
-                                            p[0]).encode("utf-8",
-                                                         errors="replace")))
+                bytes=list(token.encode("utf-8", errors="replace")))
             for i, p in enumerate(logprobs.items())
            if top_logprobs and i < top_logprobs
        ]
@@ -454,6 +468,7 @@ class OpenAIServingChat(OpenAIServing):
         self,
         token_ids: GenericSequence[int],
         top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
+        tokenizer: PreTrainedTokenizer,
         num_output_top_logprobs: Optional[int] = None,
     ) -> ChatCompletionLogProbs:
         """Create OpenAI-style logprobs."""
@@ -463,12 +478,11 @@ class OpenAIServingChat(OpenAIServing):
         for i, token_id in enumerate(token_ids):
             step_top_logprobs = top_logprobs[i]
             if step_top_logprobs is None:
+                token = tokenizer.decode(token_id)
                 logprobs_content.append(
                     ChatCompletionLogProbsContent(
-                        token=self.tokenizer.decode(token_id),
-                        bytes=list(
-                            self.tokenizer.decode(token_id).encode(
-                                "utf-8", errors="replace"))))
+                        token=token,
+                        bytes=list(token.encode("utf-8", errors="replace"))))
             else:
                 logprobs_content.append(
                     ChatCompletionLogProbsContent(
@@ -479,6 +493,7 @@ class OpenAIServingChat(OpenAIServing):
                             step_top_logprobs[token_id].decoded_token.encode(
                                 "utf-8", errors="replace")),
                         top_logprobs=self._get_top_logprobs(
-                            step_top_logprobs, num_output_top_logprobs)))
+                            step_top_logprobs, num_output_top_logprobs,
+                            tokenizer)))
         return ChatCompletionLogProbs(content=logprobs_content)
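A small self-contained illustration of the byte-level fallback used throughout the logprob handlers above: each decoded token is re-encoded as UTF-8 with errors="replace", so the OpenAI-style "bytes" field is always well defined even when a token decodes to partial or replacement characters.

token = "né"                      # stand-in for any decoded token string
byte_values = list(token.encode("utf-8", errors="replace"))
print(byte_values)                # [110, 195, 169]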

View File

@@ -5,6 +5,7 @@ from typing import Sequence as GenericSequence
 from typing import Tuple
 from fastapi import Request
+from transformers import PreTrainedTokenizer
 from vllm.config import ModelConfig
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -100,20 +101,22 @@ class OpenAIServingCompletion(OpenAIServing):
         # Schedule the request and get the result generator.
         generators: List[AsyncIterator[RequestOutput]] = []
         try:
-            sampling_params = request.to_sampling_params()
             adapter_type, adapter_request = self._maybe_get_adapter(request)
             lora_request, prompt_adapter_request = None, None
             if adapter_type == 'LoRA':
                 lora_request, prompt_adapter_request = adapter_request, None
             elif adapter_type == 'PromptAdapter':
                 lora_request, prompt_adapter_request = None, adapter_request
+            tokenizer = await self.engine.get_tokenizer(lora_request)
+            sampling_params = request.to_sampling_params()
             decoding_config = await self.engine.get_decoding_config()
             guided_decoding_backend = request.guided_decoding_backend \
                 or decoding_config.guided_decoding_backend
             guided_decode_logit_processor = (
-                await get_guided_decoding_logits_processor(
-                    guided_decoding_backend, request,
-                    self.engine.get_tokenizer()))
+                await
+                get_guided_decoding_logits_processor(guided_decoding_backend,
                                                      request, tokenizer))
             if guided_decode_logit_processor is not None:
                 if sampling_params.logits_processors is None:
                     sampling_params.logits_processors = []
@@ -122,18 +125,13 @@ class OpenAIServingCompletion(OpenAIServing):
             prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
             for i, prompt in enumerate(prompts):
-                if prompt_is_tokens:
-                    prompt_formats = self._validate_prompt_and_tokenize(
-                        request,
-                        prompt_ids=prompt,
-                        truncate_prompt_tokens=sampling_params.
-                        truncate_prompt_tokens)
-                else:
-                    prompt_formats = self._validate_prompt_and_tokenize(
-                        request,
-                        prompt=prompt,
-                        truncate_prompt_tokens=sampling_params.
-                        truncate_prompt_tokens)
+                prompt_arg = "prompt_ids" if prompt_is_tokens else "prompt"
+                prompt_formats = await self._validate_prompt_and_tokenize(
+                    request,
+                    tokenizer,
+                    truncate_prompt_tokens=sampling_params.
+                    truncate_prompt_tokens,
+                    **{prompt_arg: prompt})
                 prompt_ids, prompt_text = prompt_formats
                 is_tracing_enabled = await self.engine.is_tracing_enabled()
@@ -179,7 +177,8 @@ class OpenAIServingCompletion(OpenAIServing):
                                                   request_id,
                                                   created_time,
                                                   model_name,
-                                                  num_prompts=len(prompts))
+                                                  num_prompts=len(prompts),
+                                                  tokenizer=tokenizer)
         # Non-streaming response
         final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
@@ -191,7 +190,8 @@ class OpenAIServingCompletion(OpenAIServing):
                     return self.create_error_response("Client disconnected")
                 final_res_batch[i] = res
             response = self.request_output_to_completion_response(
-                final_res_batch, request, request_id, created_time, model_name)
+                final_res_batch, request, request_id, created_time, model_name,
+                tokenizer)
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
@@ -218,6 +218,7 @@ class OpenAIServingCompletion(OpenAIServing):
         created_time: int,
         model_name: str,
         num_prompts: int,
+        tokenizer: PreTrainedTokenizer,
     ) -> AsyncGenerator[str, None]:
         assert request.n is not None
         previous_texts = [""] * request.n * num_prompts
@@ -268,6 +269,7 @@ class OpenAIServingCompletion(OpenAIServing):
                         token_ids=delta_token_ids,
                         top_logprobs=out_logprobs,
                         num_output_top_logprobs=request.logprobs,
+                        tokenizer=tokenizer,
                         initial_text_offset=len(previous_texts[i]),
                     )
                 else:
@@ -336,6 +338,7 @@ class OpenAIServingCompletion(OpenAIServing):
         request_id: str,
         created_time: int,
         model_name: str,
+        tokenizer: PreTrainedTokenizer,
     ) -> CompletionResponse:
         choices: List[CompletionResponseChoice] = []
         num_prompt_tokens = 0
@@ -367,6 +370,7 @@ class OpenAIServingCompletion(OpenAIServing):
                     logprobs = self._create_completion_logprobs(
                         token_ids=token_ids,
                         top_logprobs=out_logprobs,
+                        tokenizer=tokenizer,
                         num_output_top_logprobs=request.logprobs,
                     )
                 else:
@@ -404,6 +408,7 @@ class OpenAIServingCompletion(OpenAIServing):
         token_ids: GenericSequence[int],
         top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
         num_output_top_logprobs: int,
+        tokenizer: PreTrainedTokenizer,
         initial_text_offset: int = 0,
     ) -> CompletionLogProbs:
         """Create logprobs for OpenAI Completion API."""
@@ -417,13 +422,13 @@ class OpenAIServingCompletion(OpenAIServing):
         for i, token_id in enumerate(token_ids):
             step_top_logprobs = top_logprobs[i]
             if step_top_logprobs is None:
-                token = self.tokenizer.decode(token_id)
+                token = tokenizer.decode(token_id)
                 out_tokens.append(token)
                 out_token_logprobs.append(None)
                 out_top_logprobs.append(None)
             else:
                 token = self._get_decoded_token(step_top_logprobs[token_id],
-                                                token_id)
+                                                token_id, tokenizer)
                 token_logprob = max(step_top_logprobs[token_id].logprob,
                                     -9999.0)
                 out_tokens.append(token)
@@ -436,7 +441,7 @@ class OpenAIServingCompletion(OpenAIServing):
                 out_top_logprobs.append({
                     # Convert float("-inf") to the
                     # JSON-serializable float that OpenAI uses
-                    self._get_decoded_token(top_lp[1], top_lp[0]):
+                    self._get_decoded_token(top_lp[1], top_lp[0], tokenizer):
                     max(top_lp[1].logprob, -9999.0)
                     for i, top_lp in enumerate(step_top_logprobs.items())
                     if num_output_top_logprobs >= i
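The **{prompt_arg: prompt} rewrite above collapses the duplicated if/else branches into a single call by choosing the keyword name at runtime. A minimal sketch of the pattern with a stand-in coroutine (the names below are illustrative, not the vLLM API):

import asyncio
from typing import List, Optional, Union

async def validate(prompt: Optional[str] = None,
                   prompt_ids: Optional[List[int]] = None) -> str:
    # stand-in for _validate_prompt_and_tokenize
    return f"text={prompt!r}, ids={prompt_ids!r}"

async def demo(raw_prompt: Union[str, List[int]]) -> str:
    prompt_is_tokens = isinstance(raw_prompt, list)
    prompt_arg = "prompt_ids" if prompt_is_tokens else "prompt"
    return await validate(**{prompt_arg: raw_prompt})

print(asyncio.run(demo("hello")))    # text='hello', ids=None
print(asyncio.run(demo([1, 2, 3])))  # text=None, ids=[1, 2, 3]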

View File

@@ -89,14 +89,11 @@ class OpenAIServingEmbedding(OpenAIServing):
         prompt_is_tokens, prompts = parse_prompt_format(request.input)
         pooling_params = request.to_pooling_params()
+        tokenizer = await self.engine.get_tokenizer()
         for i, prompt in enumerate(prompts):
-            if prompt_is_tokens:
-                prompt_formats = self._validate_prompt_and_tokenize(
-                    request, prompt_ids=prompt)
-            else:
-                prompt_formats = self._validate_prompt_and_tokenize(
-                    request, prompt=prompt)
+            prompt_arg = "prompt_ids" if prompt_is_tokens else "prompt"
+            prompt_formats = await self._validate_prompt_and_tokenize(
+                request, tokenizer, **{prompt_arg: prompt})
             prompt_ids, prompt_text = prompt_formats
             generator = self.engine.encode(

View File

@@ -5,6 +5,7 @@ from http import HTTPStatus
 from typing import Any, Dict, List, Optional, Tuple, Union
 from pydantic import Field
+from transformers import PreTrainedTokenizer
 from typing_extensions import Annotated
 from vllm.config import ModelConfig
@@ -19,7 +20,6 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import Logprob
-from vllm.transformers_utils.tokenizer import get_tokenizer
 logger = init_logger(__name__)
@@ -52,14 +52,6 @@ class OpenAIServing:
         self.model_config = model_config
         self.max_model_len = model_config.max_model_len
-        # A separate tokenizer to map token IDs to strings.
-        self.tokenizer = get_tokenizer(
-            model_config.tokenizer,
-            tokenizer_mode=model_config.tokenizer_mode,
-            tokenizer_revision=model_config.tokenizer_revision,
-            trust_remote_code=model_config.trust_remote_code,
-            truncation_side="left")
         self.served_model_names = served_model_names
         self.lora_requests = []
@@ -154,7 +146,8 @@ class OpenAIServing:
     def _maybe_get_adapter(
         self, request: Union[CompletionRequest, ChatCompletionRequest,
-                             EmbeddingRequest]
+                             EmbeddingRequest, TokenizeRequest,
+                             DetokenizeRequest]
     ) -> Tuple[Optional[str], Optional[Union[LoRARequest,
                                              PromptAdapterRequest]]]:
         if request.model in self.served_model_names:
@@ -168,11 +161,12 @@ class OpenAIServing:
         # if _check_model has been called earlier, this will be unreachable
         raise ValueError(f"The model `{request.model}` does not exist.")
-    def _validate_prompt_and_tokenize(
+    async def _validate_prompt_and_tokenize(
             self,
             request: Union[ChatCompletionRequest, CompletionRequest,
                            DetokenizeRequest, EmbeddingRequest,
                            TokenizeRequest],
+            tokenizer: "PreTrainedTokenizer",
             prompt: Optional[str] = None,
             prompt_ids: Optional[List[int]] = None,
             truncate_prompt_tokens: Optional[Annotated[int,
@@ -181,7 +175,7 @@ class OpenAIServing:
     ) -> Tuple[List[int], str]:
         if not (prompt or prompt_ids):
             raise ValueError("Either prompt or prompt_ids should be provided.")
-        if (prompt and prompt_ids):
+        if prompt and prompt_ids:
             raise ValueError(
                 "Only one of prompt or prompt_ids should be provided.")
@@ -200,14 +194,14 @@ class OpenAIServing:
                 "truncation": True,
                 "max_length": truncate_prompt_tokens,
             })
-            input_ids = self.tokenizer(prompt, **tokenizer_kwargs).input_ids
+            input_ids = tokenizer(prompt, **tokenizer_kwargs).input_ids
         elif truncate_prompt_tokens is not None:
             input_ids = prompt_ids[-truncate_prompt_tokens:]
         else:
             input_ids = prompt_ids
-        input_text = prompt if prompt is not None else self.tokenizer.decode(
-            prompt_ids)
+        input_text = prompt if prompt is not None else tokenizer.decode(
+            input_ids)
         token_num = len(input_ids)
         # Note: EmbeddingRequest doesn't have max_tokens
@@ -245,7 +239,9 @@ class OpenAIServing:
         else:
             return input_ids, input_text
-    def _get_decoded_token(self, logprob: Logprob, token_id: int) -> str:
+    @staticmethod
+    def _get_decoded_token(logprob: Logprob, token_id: int,
+                           tokenizer: PreTrainedTokenizer) -> str:
         if logprob.decoded_token is not None:
             return logprob.decoded_token
-        return self.tokenizer.decode(token_id)
+        return tokenizer.decode(token_id)

View File

@@ -9,7 +9,8 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
                                               DetokenizeResponse,
                                               TokenizeRequest,
                                               TokenizeResponse)
-from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
+                                                    OpenAIServing)
 class OpenAIServingTokenization(OpenAIServing):
@@ -18,13 +19,15 @@ class OpenAIServingTokenization(OpenAIServing):
                  engine: AsyncLLMEngine,
                  model_config: ModelConfig,
                  served_model_names: List[str],
+                 lora_modules: Optional[List[LoRAModulePath]] = None,
                  chat_template: Optional[str] = None):
         super().__init__(engine=engine,
                          model_config=model_config,
                          served_model_names=served_model_names,
-                         lora_modules=None)
-        load_chat_template(self, chat_template)
+                         lora_modules=lora_modules)
+        # If this is None we use the tokenizer's default chat template
+        self.chat_template = load_chat_template(chat_template)
     async def create_tokenize(self,
                               request: TokenizeRequest) -> TokenizeResponse:
@@ -40,20 +43,25 @@ class OpenAIServingTokenization(OpenAIServing):
             return self.create_error_response(
                 "Only one of `prompt` or `messages` should be provided.")
+        _, lora_request = self._maybe_get_adapter(request)
+        tokenizer = await self.engine.get_tokenizer(lora_request)
         if request.messages:
             conversation: List[ConversationMessage] = []
             for message in request.messages:
-                conversation.extend(
-                    parse_chat_message_content(self, message).messages)
+                result = parse_chat_message_content(message, self.model_config,
+                                                    tokenizer)
+                conversation.extend(result.messages)
-            request.prompt = self.tokenizer.apply_chat_template(
+            request.prompt = tokenizer.apply_chat_template(
                 add_generation_prompt=request.add_generation_prompt,
                 conversation=conversation,
-                tokenize=False)
+                tokenize=False,
+                chat_template=self.chat_template)
-        (input_ids, input_text) = self._validate_prompt_and_tokenize(
+        (input_ids, input_text) = await self._validate_prompt_and_tokenize(
             request,
+            tokenizer,
             prompt=request.prompt,
             add_special_tokens=request.add_special_tokens)
@@ -67,7 +75,9 @@ class OpenAIServingTokenization(OpenAIServing):
         if error_check_ret is not None:
             return error_check_ret
-        (input_ids, input_text) = self._validate_prompt_and_tokenize(
-            request, prompt_ids=request.tokens)
+        _, lora_request = self._maybe_get_adapter(request)
+        tokenizer = await self.engine.get_tokenizer(lora_request)
+        (input_ids, input_text) = await self._validate_prompt_and_tokenize(
+            request, tokenizer, prompt_ids=request.tokens)
         return DetokenizeResponse(prompt=input_text)

View File

@@ -165,6 +165,12 @@ class Detokenizer:
         return len(new_decoded_token_text)
+def _replace_none_with_empty(tokens: List[Optional[str]]):
+    for i, token in enumerate(tokens):
+        if token is None:
+            tokens[i] = ""
 def _convert_tokens_to_string_with_added_encoders(
     tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
     output_tokens: List[str],
@@ -223,6 +229,8 @@ def convert_prompt_ids_to_tokens(
     read_offset = len(new_tokens)
     prefix_offset = max(
         read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0)
+    # This is required to guard against out-of-vocab prompt token ids
+    _replace_none_with_empty(new_tokens)
     return new_tokens, prefix_offset, read_offset
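A hedged sketch of the failure mode this guard closes: prompt token IDs outside the tokenizer's vocabulary (for example, IDs that exist only in a LoRA adapter's tokenizer) come back from convert_ids_to_tokens as None, and later string joins in the incremental detokenizer would then fail, so they are replaced with empty strings.

from typing import List, Optional

def _replace_none_with_empty(tokens: List[Optional[str]]) -> None:
    # same helper as in the hunk above
    for i, token in enumerate(tokens):
        if token is None:
            tokens[i] = ""

new_tokens: List[Optional[str]] = ["▁Hello", None, "▁world"]  # None = out-of-vocab ID
_replace_none_with_empty(new_tokens)
assert new_tokens == ["▁Hello", "", "▁world"]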

View File

@@ -88,6 +88,9 @@ def get_tokenizer(
             "Cannot use the fast tokenizer in slow tokenizer mode.")
         kwargs["use_fast"] = False
+    if "truncation_side" not in kwargs:
+        kwargs["truncation_side"] = "left"
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_name,
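A short sketch of the behavior this new default enables (model name reused from the tests above; output shown only as a comment): with truncation_side="left", over-long prompts keep their most recent tokens when the serving layer truncates with truncate_prompt_tokens or the model length.

from vllm.transformers_utils.tokenizer import get_tokenizer

tok = get_tokenizer("HuggingFaceH4/zephyr-7b-beta")  # now defaults to left truncation
ids = tok("one two three four five", truncation=True, max_length=3,
          add_special_tokens=False).input_ids
print(tok.decode(ids))  # the tail of the prompt survives truncation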