[BugFix][Frontend] Use LoRA tokenizer in OpenAI APIs (#6227)

Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Nick Hill 2024-07-18 00:13:30 -07:00 committed by GitHub
parent 8a74c68bd1
commit e2fbaee725
16 changed files with 267 additions and 186 deletions

View File

@@ -1,6 +1,5 @@
import os
import pathlib
from dataclasses import dataclass
import pytest
@@ -50,23 +49,9 @@ TEST_MESSAGES = [
]
@dataclass
class MockTokenizer:
chat_template = None
@dataclass
class MockServingChat:
tokenizer: MockTokenizer
def test_load_chat_template():
# Testing chatml template
tokenizer = MockTokenizer()
mock_serving_chat = MockServingChat(tokenizer)
load_chat_template(mock_serving_chat, chat_template=chatml_jinja_path)
template_content = tokenizer.chat_template
template_content = load_chat_template(chat_template=chatml_jinja_path)
# Test assertions
assert template_content is not None
@@ -78,22 +63,16 @@ def test_load_chat_template():
def test_no_load_chat_template_filelike():
# Testing chatml template
template = "../../examples/does_not_exist"
tokenizer = MockTokenizer()
mock_serving_chat = MockServingChat(tokenizer)
with pytest.raises(ValueError, match="looks like a file path"):
load_chat_template(mock_serving_chat, chat_template=template)
load_chat_template(chat_template=template)
def test_no_load_chat_template_literallike():
# Testing chatml template
template = "{{ messages }}"
tokenizer = MockTokenizer()
mock_serving_chat = MockServingChat(tokenizer)
load_chat_template(mock_serving_chat, chat_template=template)
template_content = tokenizer.chat_template
template_content = load_chat_template(chat_template=template)
assert template_content == template
@@ -105,8 +84,7 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
expected_output):
# Initialize the tokenizer
tokenizer = get_tokenizer(tokenizer_name=model)
mock_serving_chat = MockServingChat(tokenizer)
load_chat_template(mock_serving_chat, chat_template=template)
template_content = load_chat_template(chat_template=template)
# Create a mock request object using keyword arguments
mock_request = ChatCompletionRequest(
@@ -118,7 +96,8 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
result = tokenizer.apply_chat_template(
conversation=mock_request.messages,
tokenize=False,
add_generation_prompt=mock_request.add_generation_prompt)
add_generation_prompt=mock_request.add_generation_prompt,
chat_template=mock_request.chat_template or template_content)
# Test assertion
assert result == expected_output, (

View File

@@ -7,11 +7,11 @@ import jsonschema
import openai # use the official client for correctness check
import pytest
import torch
# downloading lora to test lora requests
from huggingface_hub import snapshot_download
from openai import BadRequestError
from ...utils import RemoteOpenAIServer
from .test_completion import zephyr_lora_added_tokens_files # noqa: F401
from .test_completion import zephyr_lora_files # noqa: F401
# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
@@ -21,12 +21,7 @@ LORA_NAME = "typeof/zephyr-7b-beta-lora"
@pytest.fixture(scope="module")
def zephyr_lora_files():
return snapshot_download(repo_id=LORA_NAME)
@pytest.fixture(scope="module")
def server(zephyr_lora_files):
def server(zephyr_lora_files, zephyr_lora_added_tokens_files): # noqa: F811
args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
@@ -38,7 +33,7 @@ def server(zephyr_lora_files):
"--enable-lora",
"--lora-modules",
f"zephyr-lora={zephyr_lora_files}",
f"zephyr-lora2={zephyr_lora_files}",
f"zephyr-lora2={zephyr_lora_added_tokens_files}",
"--max-lora-rank",
"64",
"--max-cpu-loras",

View File

@@ -1,6 +1,8 @@
# imports for guided decoding tests
import json
import re
import shutil
from tempfile import TemporaryDirectory
from typing import List
import jsonschema
@@ -9,6 +11,7 @@ import pytest
# downloading lora to test lora requests
from huggingface_hub import snapshot_download
from openai import BadRequestError
from transformers import AutoTokenizer
from vllm.transformers_utils.tokenizer import get_tokenizer
@@ -30,13 +33,29 @@ def zephyr_lora_files():
return snapshot_download(repo_id=LORA_NAME)
@pytest.fixture(scope="module")
def zephyr_lora_added_tokens_files(zephyr_lora_files):
tmp_dir = TemporaryDirectory()
tmp_model_dir = f"{tmp_dir.name}/zephyr"
shutil.copytree(zephyr_lora_files, tmp_model_dir)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Copy tokenizer to adapter and add some unique tokens
# 32000, 32001, 32002
added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"],
special_tokens=True)
assert added == 3
tokenizer.save_pretrained(tmp_model_dir)
yield tmp_model_dir
tmp_dir.cleanup()
@pytest.fixture(scope="module")
def zephyr_pa_files():
return snapshot_download(repo_id=PA_NAME)
@pytest.fixture(scope="module")
def server(zephyr_lora_files, zephyr_pa_files):
def server(zephyr_lora_files, zephyr_lora_added_tokens_files, zephyr_pa_files):
args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
@@ -50,7 +69,7 @@ def server(zephyr_lora_files, zephyr_pa_files):
"--enable-lora",
"--lora-modules",
f"zephyr-lora={zephyr_lora_files}",
f"zephyr-lora2={zephyr_lora_files}",
f"zephyr-lora2={zephyr_lora_added_tokens_files}",
"--max-lora-rank",
"64",
"--max-cpu-loras",
@@ -111,6 +130,34 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
assert len(completion.choices[0].text) >= 1
@pytest.mark.asyncio
async def test_added_lora_tokens(client: openai.AsyncOpenAI):
# test using token IDs
completion = await client.completions.create(
model="zephyr-lora2",
prompt=[0, 0, 32000, 32001, 32002],
echo=True,
max_tokens=5,
temperature=0.0,
)
# Added tokens should appear in tokenized prompt
assert completion.choices[0].text.startswith("<unk><unk>vllm1vllm2vllm3")
@pytest.mark.asyncio
async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI):
# test using token IDs
completion = await client.completions.create(
model=MODEL_NAME,
prompt=[0, 0, 32000, 32001, 32002],
echo=True,
max_tokens=5,
temperature=0.0,
)
# Added tokens should not appear in tokenized prompt
assert "vllm" not in completion.choices[0].text
@pytest.mark.asyncio
@pytest.mark.parametrize(
# first test base model, then test loras, then test prompt adapters

View File

@@ -38,5 +38,4 @@ async def _async_serving_chat_init():
def test_async_serving_chat_init():
serving_completion = asyncio.run(_async_serving_chat_init())
assert serving_completion.tokenizer is not None
assert serving_completion.tokenizer.chat_template == CHAT_TEMPLATE
assert serving_completion.chat_template == CHAT_TEMPLATE

View File

@@ -5,13 +5,15 @@ import requests
from vllm.transformers_utils.tokenizer import get_tokenizer
from ...utils import RemoteOpenAIServer
from .test_completion import zephyr_lora_added_tokens_files # noqa: F401
from .test_completion import zephyr_lora_files # noqa: F401
# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
@pytest.fixture(scope="module")
def server():
def server(zephyr_lora_added_tokens_files: str): # noqa: F811
args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
@@ -21,12 +23,25 @@ def server():
"--enforce-eager",
"--max-num-seqs",
"128",
# lora config
"--enable-lora",
"--lora-modules",
f"zephyr-lora2={zephyr_lora_added_tokens_files}",
"--max-lora-rank",
"64",
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest.fixture(scope="module")
def tokenizer_name(model_name: str,
zephyr_lora_added_tokens_files: str): # noqa: F811
return zephyr_lora_added_tokens_files if (
model_name == "zephyr-lora2") else model_name
@pytest.fixture(scope="module")
def client(server):
return server.get_async_client()
@@ -34,16 +49,18 @@ def client(server):
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
"model_name,tokenizer_name",
[(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
indirect=["tokenizer_name"],
)
async def test_tokenize_completions(client: openai.AsyncOpenAI,
model_name: str):
model_name: str, tokenizer_name: str):
base_url = str(client.base_url)[:-3].strip("/")
tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
tokenizer_mode="fast")
for add_special in [False, True]:
prompt = "This is a test prompt."
prompt = "vllm1 This is a test prompt."
tokens = tokenizer.encode(prompt, add_special_tokens=add_special)
response = requests.post(base_url + "/tokenize",
@@ -63,12 +80,15 @@ async def test_tokenize_completions(client: openai.AsyncOpenAI,
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
"model_name,tokenizer_name",
[(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
indirect=["tokenizer_name"],
)
async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str):
async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str,
tokenizer_name: str):
base_url = str(client.base_url)[:-3].strip("/")
tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
tokenizer_mode="fast")
for add_generation in [False, True]:
for add_special in [False, True]:
@@ -80,7 +100,7 @@ async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str):
"content": "Nice to meet you!"
}, {
"role": "user",
"content": "Can I ask a question?"
"content": "Can I ask a question? vllm1"
}]
prompt = tokenizer.apply_chat_template(
@@ -108,16 +128,20 @@ async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str):
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
"model_name,tokenizer_name",
[(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
indirect=["tokenizer_name"],
)
async def test_detokenize(client: openai.AsyncOpenAI, model_name: str):
async def test_detokenize(client: openai.AsyncOpenAI, model_name: str,
tokenizer_name: str):
base_url = str(client.base_url)[:-3].strip("/")
tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
tokenizer_mode="fast")
prompt = "This is a test prompt."
prompt = "This is a test prompt. vllm1"
tokens = tokenizer.encode(prompt, add_special_tokens=False)
print(f"CALLING {base_url} FOR {model_name}")
response = requests.post(base_url + "/detokenize",
json={
"model": model_name,

View File

@@ -480,11 +480,16 @@ class AsyncLLMEngine:
self.set_errored(exc)
self._request_tracker.propagate_exception(exc)
async def get_tokenizer(self) -> "PreTrainedTokenizer":
async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> "PreTrainedTokenizer":
if self.engine_use_ray:
return await self.engine.get_tokenizer.remote() # type: ignore
else:
return self.engine.get_tokenizer()
return await self.engine.get_tokenizer.remote( # type: ignore
lora_request)
return await (self.engine.get_tokenizer_group().
get_lora_tokenizer_async(lora_request))
def start_background_loop(self) -> None:
"""Start the background loop."""

View File

@@ -455,8 +455,11 @@ class LLMEngine:
return self.tokenizer
def get_tokenizer(self) -> "PreTrainedTokenizer":
return self.get_tokenizer_group().get_lora_tokenizer(None)
def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
def get_tokenizer_for_seq(self,
sequence: Sequence) -> "PreTrainedTokenizer":

View File

@@ -257,7 +257,8 @@ def run_server(args, llm_engine=None):
openai_serving_embedding = OpenAIServingEmbedding(engine, model_config,
served_model_names)
openai_serving_tokenization = OpenAIServingTokenization(
engine, model_config, served_model_names, args.chat_template)
engine, model_config, served_model_names, args.lora_modules,
args.chat_template)
app.root_path = args.root_path
logger.info("Available routes are:")

View File

@@ -5,10 +5,11 @@ from typing import Awaitable, Iterable, List, Optional, TypedDict, cast, final
from openai.types.chat import (ChatCompletionContentPartImageParam,
ChatCompletionContentPartTextParam)
from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig
from vllm.entrypoints.openai.protocol import (ChatCompletionContentPartParam,
ChatCompletionMessageParam)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import async_get_and_parse_image
@@ -29,13 +30,12 @@ class ChatMessageParseResult:
default_factory=list)
def load_chat_template(engine: OpenAIServing, chat_template: Optional[str]):
tokenizer = engine.tokenizer
if chat_template is not None:
def load_chat_template(chat_template: Optional[str]) -> Optional[str]:
if chat_template is None:
return None
try:
with open(chat_template, "r") as f:
tokenizer.chat_template = f.read()
resolved_chat_template = f.read()
except OSError as e:
JINJA_CHARS = "{}\n"
if not any(c in chat_template for c in JINJA_CHARS):
@@ -46,23 +46,18 @@ def load_chat_template(engine: OpenAIServing, chat_template: Optional[str]):
# If opening a file fails, treat chat_template as the literal template
# string and decode it so escape sequences are interpreted correctly
tokenizer.chat_template = codecs.decode(chat_template,
"unicode_escape")
resolved_chat_template = codecs.decode(chat_template, "unicode_escape")
logger.info("Using supplied chat template:\n%s",
tokenizer.chat_template)
elif tokenizer.chat_template is not None:
logger.info("Using default chat template:\n%s",
tokenizer.chat_template)
else:
logger.warning("No chat template provided. Chat API will not work.")
logger.info("Using supplied chat template:\n%s", resolved_chat_template)
return resolved_chat_template
@lru_cache(maxsize=None)
def _image_token_str(engine: OpenAIServing) -> Optional[str]:
def _image_token_str(model_config: ModelConfig,
tokenizer: PreTrainedTokenizer) -> Optional[str]:
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
model_type = engine.model_config.hf_config.model_type
model_type = model_config.hf_config.model_type
if model_type == "phi3_v":
# Workaround since this token is not defined in the tokenizer
return "<|image_1|>"
@@ -70,17 +65,14 @@ def _image_token_str(engine: OpenAIServing) -> Optional[str]:
# These models do not use image tokens in the prompt
return None
if model_type.startswith("llava"):
return engine.tokenizer.decode(
engine.model_config.hf_config.image_token_index)
return tokenizer.decode(model_config.hf_config.image_token_index)
else:
raise TypeError("Unknown model type: {model_type}")
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
def _get_full_image_text_prompt(engine: OpenAIServing, image_token_str: str,
text_prompt: str) -> str:
def _get_full_image_text_prompt(image_token_str: str, text_prompt: str) -> str:
"""Combine image and text prompts for vision language model"""
# NOTE: For now we assume all model architectures use the same
@@ -89,9 +81,10 @@ def _get_full_image_text_prompt(engine: OpenAIServing, image_token_str: str,
def _parse_chat_message_content_parts(
engine: OpenAIServing,
role: str,
parts: Iterable[ChatCompletionContentPartParam],
model_config: ModelConfig,
tokenizer: PreTrainedTokenizer,
) -> ChatMessageParseResult:
texts: List[str] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = []
@@ -122,7 +115,7 @@ def _parse_chat_message_content_parts(
text_prompt = "\n".join(texts)
if mm_futures:
image_token_str = _image_token_str(engine)
image_token_str = _image_token_str(model_config, tokenizer)
if image_token_str is not None:
if image_token_str in text_prompt:
logger.warning(
@@ -130,7 +123,6 @@ def _parse_chat_message_content_parts(
"Skipping prompt formatting.")
else:
text_prompt = _get_full_image_text_prompt(
engine,
image_token_str=image_token_str,
text_prompt=text_prompt,
)
@@ -141,8 +133,9 @@ def _parse_chat_message_content_parts(
def parse_chat_message_content(
engine: OpenAIServing,
message: ChatCompletionMessageParam,
model_config: ModelConfig,
tokenizer: PreTrainedTokenizer,
) -> ChatMessageParseResult:
role = message["role"]
content = message.get("content")
@@ -153,4 +146,5 @@ def parse_chat_message_content(
messages = [ConversationMessage(role=role, content=content)]
return ChatMessageParseResult(messages=messages, mm_futures=[])
return _parse_chat_message_content_parts(engine, role, content)
return _parse_chat_message_content_parts(role, content, model_config,
tokenizer)
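
For reference, a rough sketch of the return-value semantics load_chat_template now has; the import path and example file are assumptions, so adapt them to where this helper lives in your tree:

# Import path assumed for illustration; adjust to the module changed above.
from vllm.entrypoints.openai.chat_utils import load_chat_template

# A literal Jinja template is unicode-unescaped and returned unchanged.
assert load_chat_template("{{ messages }}") == "{{ messages }}"
# A readable file path is opened and its contents returned.
chatml = load_chat_template("examples/template_chatml.jinja")  # path assumed
# None means "defer to the tokenizer's built-in chat template" downstream.
assert load_chat_template(None) is None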

View File

@@ -5,6 +5,7 @@ from typing import Sequence as GenericSequence
from typing import Union
from fastapi import Request
from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -49,7 +50,9 @@ class OpenAIServingChat(OpenAIServing):
lora_modules=lora_modules)
self.response_role = response_role
load_chat_template(self, chat_template)
# If this is None we use the tokenizer's default chat template
self.chat_template = load_chat_template(chat_template)
async def create_chat_completion(
self,
@@ -71,11 +74,15 @@ class OpenAIServingChat(OpenAIServing):
return error_check_ret
try:
_, lora_request = self._maybe_get_adapter(request)
tokenizer = await self.engine.get_tokenizer(lora_request)
conversation: List[ConversationMessage] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = []
for msg in request.messages:
chat_parsed_result = parse_chat_message_content(self, msg)
chat_parsed_result = parse_chat_message_content(
msg, self.model_config, tokenizer)
conversation.extend(chat_parsed_result.messages)
mm_futures.extend(chat_parsed_result.mm_futures)
@@ -84,13 +91,13 @@ class OpenAIServingChat(OpenAIServing):
tool.model_dump() for tool in request.tools
]
prompt = self.tokenizer.apply_chat_template(
prompt = tokenizer.apply_chat_template(
conversation=conversation,
tokenize=False,
add_generation_prompt=request.add_generation_prompt,
tools=tool_dicts,
documents=request.documents,
chat_template=request.chat_template,
chat_template=request.chat_template or self.chat_template,
**(request.chat_template_kwargs or {}),
)
except Exception as e:
@@ -112,19 +119,19 @@ class OpenAIServingChat(OpenAIServing):
request_id = f"cmpl-{random_uuid()}"
try:
# Tokenize/detokenize depending on prompt format (string/token list)
prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
prompt_ids, prompt_text = await self._validate_prompt_and_tokenize(
request,
tokenizer,
prompt=prompt,
add_special_tokens=request.add_special_tokens)
sampling_params = request.to_sampling_params()
_, lora_request = self._maybe_get_adapter(request)
decoding_config = await self.engine.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
guided_decode_logits_processor = (
await get_guided_decoding_logits_processor(
guided_decoding_backend, request, await
self.engine.get_tokenizer()))
await
get_guided_decoding_logits_processor(guided_decoding_backend,
request, tokenizer))
if guided_decode_logits_processor:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = []
@@ -158,12 +165,12 @@ class OpenAIServingChat(OpenAIServing):
# Streaming response
if request.stream:
return self.chat_completion_stream_generator(
request, result_generator, request_id, conversation)
request, result_generator, request_id, conversation, tokenizer)
else:
try:
return await self.chat_completion_full_generator(
request, raw_request, result_generator, request_id,
conversation)
conversation, tokenizer)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
@@ -175,9 +182,12 @@ class OpenAIServingChat(OpenAIServing):
return request.messages[-1]["role"]
async def chat_completion_stream_generator(
self, request: ChatCompletionRequest,
result_generator: AsyncIterator[RequestOutput], request_id: str,
conversation: List[ConversationMessage]
self,
request: ChatCompletionRequest,
result_generator: AsyncIterator[RequestOutput],
request_id: str,
conversation: List[ConversationMessage],
tokenizer: PreTrainedTokenizer,
) -> AsyncGenerator[str, None]:
model_name = self.served_model_names[0]
created_time = int(time.time())
@@ -264,6 +274,7 @@ class OpenAIServingChat(OpenAIServing):
logprobs = self._create_chat_logprobs(
token_ids=delta_token_ids,
top_logprobs=out_logprobs,
tokenizer=tokenizer,
num_output_top_logprobs=request.top_logprobs,
)
else:
@@ -352,9 +363,13 @@ class OpenAIServingChat(OpenAIServing):
yield "data: [DONE]\n\n"
async def chat_completion_full_generator(
self, request: ChatCompletionRequest, raw_request: Optional[Request],
result_generator: AsyncIterator[RequestOutput], request_id: str,
conversation: List[ConversationMessage]
self,
request: ChatCompletionRequest,
raw_request: Optional[Request],
result_generator: AsyncIterator[RequestOutput],
request_id: str,
conversation: List[ConversationMessage],
tokenizer: PreTrainedTokenizer,
) -> Union[ErrorResponse, ChatCompletionResponse]:
model_name = self.served_model_names[0]
@@ -382,6 +397,7 @@ class OpenAIServingChat(OpenAIServing):
token_ids=token_ids,
top_logprobs=out_logprobs,
num_output_top_logprobs=request.top_logprobs,
tokenizer=tokenizer,
)
else:
logprobs = None
@@ -436,16 +452,14 @@ class OpenAIServingChat(OpenAIServing):
return response
def _get_top_logprobs(
self, logprobs: Dict[int, Logprob],
top_logprobs: Optional[int]) -> List[ChatCompletionLogProb]:
self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
tokenizer: PreTrainedTokenizer) -> List[ChatCompletionLogProb]:
return [
ChatCompletionLogProb(
token=self._get_decoded_token(p[1], p[0]),
token=(token := self._get_decoded_token(p[1], p[0],
tokenizer)),
logprob=max(p[1].logprob, -9999.0),
bytes=list(
self._get_decoded_token(p[1],
p[0]).encode("utf-8",
errors="replace")))
bytes=list(token.encode("utf-8", errors="replace")))
for i, p in enumerate(logprobs.items())
if top_logprobs and i < top_logprobs
]
@@ -454,6 +468,7 @@ class OpenAIServingChat(OpenAIServing):
self,
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
tokenizer: PreTrainedTokenizer,
num_output_top_logprobs: Optional[int] = None,
) -> ChatCompletionLogProbs:
"""Create OpenAI-style logprobs."""
@@ -463,12 +478,11 @@ class OpenAIServingChat(OpenAIServing):
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None:
token = tokenizer.decode(token_id)
logprobs_content.append(
ChatCompletionLogProbsContent(
token=self.tokenizer.decode(token_id),
bytes=list(
self.tokenizer.decode(token_id).encode(
"utf-8", errors="replace"))))
token=token,
bytes=list(token.encode("utf-8", errors="replace"))))
else:
logprobs_content.append(
ChatCompletionLogProbsContent(
@@ -479,6 +493,7 @@ class OpenAIServingChat(OpenAIServing):
step_top_logprobs[token_id].decoded_token.encode(
"utf-8", errors="replace")),
top_logprobs=self._get_top_logprobs(
step_top_logprobs, num_output_top_logprobs)))
step_top_logprobs, num_output_top_logprobs,
tokenizer)))
return ChatCompletionLogProbs(content=logprobs_content)

View File

@@ -5,6 +5,7 @@ from typing import Sequence as GenericSequence
from typing import Tuple
from fastapi import Request
from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -100,20 +101,22 @@ class OpenAIServingCompletion(OpenAIServing):
# Schedule the request and get the result generator.
generators: List[AsyncIterator[RequestOutput]] = []
try:
sampling_params = request.to_sampling_params()
adapter_type, adapter_request = self._maybe_get_adapter(request)
lora_request, prompt_adapter_request = None, None
if adapter_type == 'LoRA':
lora_request, prompt_adapter_request = adapter_request, None
elif adapter_type == 'PromptAdapter':
lora_request, prompt_adapter_request = None, adapter_request
tokenizer = await self.engine.get_tokenizer(lora_request)
sampling_params = request.to_sampling_params()
decoding_config = await self.engine.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
guided_decode_logit_processor = (
await get_guided_decoding_logits_processor(
guided_decoding_backend, request, await
self.engine.get_tokenizer()))
await
get_guided_decoding_logits_processor(guided_decoding_backend,
request, tokenizer))
if guided_decode_logit_processor is not None:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = []
@@ -122,18 +125,13 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
for i, prompt in enumerate(prompts):
if prompt_is_tokens:
prompt_formats = self._validate_prompt_and_tokenize(
prompt_arg = "prompt_ids" if prompt_is_tokens else "prompt"
prompt_formats = await self._validate_prompt_and_tokenize(
request,
prompt_ids=prompt,
tokenizer,
truncate_prompt_tokens=sampling_params.
truncate_prompt_tokens)
else:
prompt_formats = self._validate_prompt_and_tokenize(
request,
prompt=prompt,
truncate_prompt_tokens=sampling_params.
truncate_prompt_tokens)
truncate_prompt_tokens,
**{prompt_arg: prompt})
prompt_ids, prompt_text = prompt_formats
is_tracing_enabled = await self.engine.is_tracing_enabled()
@@ -179,7 +177,8 @@ class OpenAIServingCompletion(OpenAIServing):
request_id,
created_time,
model_name,
num_prompts=len(prompts))
num_prompts=len(prompts),
tokenizer=tokenizer)
# Non-streaming response
final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
@@ -191,7 +190,8 @@ class OpenAIServingCompletion(OpenAIServing):
return self.create_error_response("Client disconnected")
final_res_batch[i] = res
response = self.request_output_to_completion_response(
final_res_batch, request, request_id, created_time, model_name)
final_res_batch, request, request_id, created_time, model_name,
tokenizer)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
@@ -218,6 +218,7 @@ class OpenAIServingCompletion(OpenAIServing):
created_time: int,
model_name: str,
num_prompts: int,
tokenizer: PreTrainedTokenizer,
) -> AsyncGenerator[str, None]:
assert request.n is not None
previous_texts = [""] * request.n * num_prompts
@@ -268,6 +269,7 @@ class OpenAIServingCompletion(OpenAIServing):
token_ids=delta_token_ids,
top_logprobs=out_logprobs,
num_output_top_logprobs=request.logprobs,
tokenizer=tokenizer,
initial_text_offset=len(previous_texts[i]),
)
else:
@@ -336,6 +338,7 @@ class OpenAIServingCompletion(OpenAIServing):
request_id: str,
created_time: int,
model_name: str,
tokenizer: PreTrainedTokenizer,
) -> CompletionResponse:
choices: List[CompletionResponseChoice] = []
num_prompt_tokens = 0
@@ -367,6 +370,7 @@ class OpenAIServingCompletion(OpenAIServing):
logprobs = self._create_completion_logprobs(
token_ids=token_ids,
top_logprobs=out_logprobs,
tokenizer=tokenizer,
num_output_top_logprobs=request.logprobs,
)
else:
@@ -404,6 +408,7 @@ class OpenAIServingCompletion(OpenAIServing):
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
num_output_top_logprobs: int,
tokenizer: PreTrainedTokenizer,
initial_text_offset: int = 0,
) -> CompletionLogProbs:
"""Create logprobs for OpenAI Completion API."""
@@ -417,13 +422,13 @@ class OpenAIServingCompletion(OpenAIServing):
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None:
token = self.tokenizer.decode(token_id)
token = tokenizer.decode(token_id)
out_tokens.append(token)
out_token_logprobs.append(None)
out_top_logprobs.append(None)
else:
token = self._get_decoded_token(step_top_logprobs[token_id],
token_id)
token_id, tokenizer)
token_logprob = max(step_top_logprobs[token_id].logprob,
-9999.0)
out_tokens.append(token)
@@ -436,7 +441,7 @@ class OpenAIServingCompletion(OpenAIServing):
out_top_logprobs.append({
# Convert float("-inf") to the
# JSON-serializable float that OpenAI uses
self._get_decoded_token(top_lp[1], top_lp[0]):
self._get_decoded_token(top_lp[1], top_lp[0], tokenizer):
max(top_lp[1].logprob, -9999.0)
for i, top_lp in enumerate(step_top_logprobs.items())
if num_output_top_logprobs >= i

View File

@@ -89,14 +89,11 @@ class OpenAIServingEmbedding(OpenAIServing):
prompt_is_tokens, prompts = parse_prompt_format(request.input)
pooling_params = request.to_pooling_params()
tokenizer = await self.engine.get_tokenizer()
for i, prompt in enumerate(prompts):
if prompt_is_tokens:
prompt_formats = self._validate_prompt_and_tokenize(
request, prompt_ids=prompt)
else:
prompt_formats = self._validate_prompt_and_tokenize(
request, prompt=prompt)
prompt_arg = "prompt_ids" if prompt_is_tokens else "prompt"
prompt_formats = await self._validate_prompt_and_tokenize(
request, tokenizer, **{prompt_arg: prompt})
prompt_ids, prompt_text = prompt_formats
generator = self.engine.encode(

View File

@@ -5,6 +5,7 @@ from http import HTTPStatus
from typing import Any, Dict, List, Optional, Tuple, Union
from pydantic import Field
from transformers import PreTrainedTokenizer
from typing_extensions import Annotated
from vllm.config import ModelConfig
@@ -19,7 +20,6 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import get_tokenizer
logger = init_logger(__name__)
@@ -52,14 +52,6 @@ class OpenAIServing:
self.model_config = model_config
self.max_model_len = model_config.max_model_len
# A separate tokenizer to map token IDs to strings.
self.tokenizer = get_tokenizer(
model_config.tokenizer,
tokenizer_mode=model_config.tokenizer_mode,
tokenizer_revision=model_config.tokenizer_revision,
trust_remote_code=model_config.trust_remote_code,
truncation_side="left")
self.served_model_names = served_model_names
self.lora_requests = []
@@ -154,7 +146,8 @@ class OpenAIServing:
def _maybe_get_adapter(
self, request: Union[CompletionRequest, ChatCompletionRequest,
EmbeddingRequest]
EmbeddingRequest, TokenizeRequest,
DetokenizeRequest]
) -> Tuple[Optional[str], Optional[Union[LoRARequest,
PromptAdapterRequest]]]:
if request.model in self.served_model_names:
@@ -168,11 +161,12 @@ class OpenAIServing:
# if _check_model has been called earlier, this will be unreachable
raise ValueError(f"The model `{request.model}` does not exist.")
def _validate_prompt_and_tokenize(
async def _validate_prompt_and_tokenize(
self,
request: Union[ChatCompletionRequest, CompletionRequest,
DetokenizeRequest, EmbeddingRequest,
TokenizeRequest],
tokenizer: "PreTrainedTokenizer",
prompt: Optional[str] = None,
prompt_ids: Optional[List[int]] = None,
truncate_prompt_tokens: Optional[Annotated[int,
@@ -181,7 +175,7 @@ class OpenAIServing:
) -> Tuple[List[int], str]:
if not (prompt or prompt_ids):
raise ValueError("Either prompt or prompt_ids should be provided.")
if (prompt and prompt_ids):
if prompt and prompt_ids:
raise ValueError(
"Only one of prompt or prompt_ids should be provided.")
@@ -200,14 +194,14 @@ class OpenAIServing:
"truncation": True,
"max_length": truncate_prompt_tokens,
})
input_ids = self.tokenizer(prompt, **tokenizer_kwargs).input_ids
input_ids = tokenizer(prompt, **tokenizer_kwargs).input_ids
elif truncate_prompt_tokens is not None:
input_ids = prompt_ids[-truncate_prompt_tokens:]
else:
input_ids = prompt_ids
input_text = prompt if prompt is not None else self.tokenizer.decode(
prompt_ids)
input_text = prompt if prompt is not None else tokenizer.decode(
input_ids)
token_num = len(input_ids)
# Note: EmbeddingRequest doesn't have max_tokens
@@ -245,7 +239,9 @@ class OpenAIServing:
else:
return input_ids, input_text
def _get_decoded_token(self, logprob: Logprob, token_id: int) -> str:
@staticmethod
def _get_decoded_token(logprob: Logprob, token_id: int,
tokenizer: PreTrainedTokenizer) -> str:
if logprob.decoded_token is not None:
return logprob.decoded_token
return self.tokenizer.decode(token_id)
return tokenizer.decode(token_id)

View File

@@ -9,7 +9,8 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
DetokenizeResponse,
TokenizeRequest,
TokenizeResponse)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing)
class OpenAIServingTokenization(OpenAIServing):
@@ -18,13 +19,15 @@ class OpenAIServingTokenization(OpenAIServing):
engine: AsyncLLMEngine,
model_config: ModelConfig,
served_model_names: List[str],
lora_modules: Optional[List[LoRAModulePath]] = None,
chat_template: Optional[str] = None):
super().__init__(engine=engine,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=None)
lora_modules=lora_modules)
load_chat_template(self, chat_template)
# If this is None we use the tokenizer's default chat template
self.chat_template = load_chat_template(chat_template)
async def create_tokenize(self,
request: TokenizeRequest) -> TokenizeResponse:
@@ -40,20 +43,25 @@ class OpenAIServingTokenization(OpenAIServing):
return self.create_error_response(
"Only one of `prompt` or `messages` should be provided.")
_, lora_request = self._maybe_get_adapter(request)
tokenizer = await self.engine.get_tokenizer(lora_request)
if request.messages:
conversation: List[ConversationMessage] = []
for message in request.messages:
conversation.extend(
parse_chat_message_content(self, message).messages)
result = parse_chat_message_content(message, self.model_config,
tokenizer)
conversation.extend(result.messages)
request.prompt = self.tokenizer.apply_chat_template(
request.prompt = tokenizer.apply_chat_template(
add_generation_prompt=request.add_generation_prompt,
conversation=conversation,
tokenize=False)
tokenize=False,
chat_template=self.chat_template)
(input_ids, input_text) = self._validate_prompt_and_tokenize(
(input_ids, input_text) = await self._validate_prompt_and_tokenize(
request,
tokenizer,
prompt=request.prompt,
add_special_tokens=request.add_special_tokens)
@@ -67,7 +75,9 @@ class OpenAIServingTokenization(OpenAIServing):
if error_check_ret is not None:
return error_check_ret
(input_ids, input_text) = self._validate_prompt_and_tokenize(
request, prompt_ids=request.tokens)
_, lora_request = self._maybe_get_adapter(request)
tokenizer = await self.engine.get_tokenizer(lora_request)
(input_ids, input_text) = await self._validate_prompt_and_tokenize(
request, tokenizer, prompt_ids=request.tokens)
return DetokenizeResponse(prompt=input_text)
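
A rough client-side sketch of the behavior these endpoints gain; the server URL, adapter name, and exact response fields are assumptions based on the tests above:

# Client-side sketch; URL and LoRA module name are assumptions.
import requests

base_url = "http://localhost:8000"
# With the fix, naming a LoRA module tokenizes with that adapter's tokenizer,
# so adapter-added tokens such as "vllm1" resolve to their added ids.
resp = requests.post(base_url + "/tokenize",
                     json={"model": "zephyr-lora2",
                           "prompt": "vllm1 This is a test prompt."})
resp.raise_for_status()
print(resp.json())  # e.g. {"tokens": [...], "count": ..., "max_model_len": ...}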

View File

@@ -165,6 +165,12 @@ class Detokenizer:
return len(new_decoded_token_text)
def _replace_none_with_empty(tokens: List[Optional[str]]):
for i, token in enumerate(tokens):
if token is None:
tokens[i] = ""
def _convert_tokens_to_string_with_added_encoders(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
output_tokens: List[str],
@@ -223,6 +229,8 @@ def convert_prompt_ids_to_tokens(
read_offset = len(new_tokens)
prefix_offset = max(
read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0)
# This is required to guard against out-of-vocab prompt token ids
_replace_none_with_empty(new_tokens)
return new_tokens, prefix_offset, read_offset
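
A small illustration of what the new guard handles: the base tokenizer's convert_ids_to_tokens can yield None for prompt token ids it does not know (for example, ids added only by a LoRA adapter's tokenizer), so those entries are blanked out before further string handling. The values below are made up and reuse the helper added above:

# Made-up example; _replace_none_with_empty is the helper defined in this diff.
new_tokens = ["Hello", None, "world"]  # None stands in for an out-of-vocab id
_replace_none_with_empty(new_tokens)
assert new_tokens == ["Hello", "", "world"]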

View File

@@ -88,6 +88,9 @@ def get_tokenizer(
"Cannot use the fast tokenizer in slow tokenizer mode.")
kwargs["use_fast"] = False
if "truncation_side" not in kwargs:
kwargs["truncation_side"] = "left"
try:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,